Posted to commits@spark.apache.org by we...@apache.org on 2022/03/23 02:10:52 UTC

[spark] branch branch-3.3 updated: [SPARK-32268][SQL] Row-level Runtime Filtering

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 2b59a96  [SPARK-32268][SQL] Row-level Runtime Filtering
2b59a96 is described below

commit 2b59a9616c5922d515db430752044869abd93746
Author: Abhishek Somani <ab...@databricks.com>
AuthorDate: Wed Mar 23 09:57:28 2022 +0800

    [SPARK-32268][SQL] Row-level Runtime Filtering
    
    ### What changes were proposed in this pull request?
    
    This PR proposes row-level runtime filters in Spark to reduce the intermediate data volume of operators like shuffle, join and aggregate, and hence improve performance. We propose two mechanisms to do this, semi-join filters and Bloom filters, which co-exist side by side behind feature configs.
    See the [Design Doc](https://docs.google.com/document/d/16IEuyLeQlubQkH8YuVuXWKo2-grVIoDJqQpHZrE7q04/edit?usp=sharing) for more details.
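    Both mechanisms are disabled by default. A minimal usage sketch of turning them on via the configs added in this patch, assuming a Spark 3.3 session named `spark`:
    
    ```scala
    // Enable the Bloom filter based runtime filter (added by this PR, default: false).
    spark.conf.set("spark.sql.optimizer.runtime.bloomFilter.enabled", "true")
    // Or enable the semi-join reduction variant instead (also default: false).
    spark.conf.set("spark.sql.optimizer.runtimeFilter.semiJoinReduction.enabled", "true")
    ```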
    
    ### Why are the changes needed?
    
    With the semi-join filter, we see 9 queries improve on the TPC-DS 3TB benchmark, with no regressions.
    With the Bloom filter, we see 10 queries improve on the TPC-DS 3TB benchmark, with no regressions.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added tests
    
    Closes #35789 from somani/rf.
    
    Lead-authored-by: Abhishek Somani <ab...@databricks.com>
    Co-authored-by: Abhishek Somani <ab...@gmail.com>
    Co-authored-by: Yuming Wang <yu...@ebay.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 1f4e4c812a9dc6d7e35631c1663c1ba6f6d9b721)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../org/apache/spark/util/sketch/BloomFilter.java  |   7 +
 .../apache/spark/util/sketch/BloomFilterImpl.java  |   5 +
 .../expressions/BloomFilterMightContain.scala      | 113 +++++
 .../aggregate/BloomFilterAggregate.scala           | 179 ++++++++
 .../sql/catalyst/expressions/objects/objects.scala |   2 +
 .../sql/catalyst/expressions/predicates.scala      |  16 +
 .../catalyst/expressions/regexpExpressions.scala   |   5 +-
 .../catalyst/optimizer/InjectRuntimeFilter.scala   | 303 +++++++++++++
 .../spark/sql/catalyst/trees/TreePatterns.scala    |   3 +
 .../org/apache/spark/sql/internal/SQLConf.scala    |  80 ++++
 .../spark/sql/execution/SparkOptimizer.scala       |   2 +
 .../dynamicpruning/PartitionPruning.scala          |  15 -
 .../spark/sql/BloomFilterAggregateQuerySuite.scala | 215 +++++++++
 .../spark/sql/InjectRuntimeFilterSuite.scala       | 503 +++++++++++++++++++++
 14 files changed, 1432 insertions(+), 16 deletions(-)

diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
index c53987e..2a6e270 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
@@ -164,6 +164,13 @@ public abstract class BloomFilter {
   public abstract void writeTo(OutputStream out) throws IOException;
 
   /**
+   * @return the number of set bits in this {@link BloomFilter}.
+   */
+  public long cardinality() {
+    throw new UnsupportedOperationException("Not implemented");
+  }
+
+  /**
    * Reads in a {@link BloomFilter} from an input stream. It is the caller's responsibility to close
    * the stream.
    */
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
index e7766ee..ccf1833 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
@@ -207,6 +207,11 @@ class BloomFilterImpl extends BloomFilter implements Serializable {
     return this;
   }
 
+  @Override
+  public long cardinality() {
+    return this.bits.cardinality();
+  }
+
   private BloomFilterImpl checkCompatibilityForMerge(BloomFilter other)
           throws IncompatibleMergeException {
     // Duplicates the logic of `isCompatible` here to provide better error message.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala
new file mode 100644
index 0000000..cf052f8
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import java.io.ByteArrayInputStream
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, JavaCode, TrueLiteral}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
+import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE
+import org.apache.spark.sql.types._
+import org.apache.spark.util.sketch.BloomFilter
+
+/**
+ * An internal scalar function that returns the membership check result (either true or false)
+ * for values of `valueExpression` in the Bloom filter represented by `bloomFilterExpression`.
+ * Note that since the function is "might contain", always returning true regardless is not
+ * wrong.
+ * Note that this expression requires that `bloomFilterExpression` is either a constant value or
+ * an uncorrelated scalar subquery. This is sufficient for the Bloom filter join rewrite.
+ *
+ * @param bloomFilterExpression the Binary data of Bloom filter.
+ * @param valueExpression the Long value to be tested for the membership of `bloomFilterExpression`.
+ */
+case class BloomFilterMightContain(
+    bloomFilterExpression: Expression,
+    valueExpression: Expression) extends BinaryExpression {
+
+  override def nullable: Boolean = true
+  override def left: Expression = bloomFilterExpression
+  override def right: Expression = valueExpression
+  override def prettyName: String = "might_contain"
+  override def dataType: DataType = BooleanType
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    (left.dataType, right.dataType) match {
+      case (BinaryType, NullType) | (NullType, LongType) | (NullType, NullType) |
+           (BinaryType, LongType) =>
+        bloomFilterExpression match {
+          case e : Expression if e.foldable => TypeCheckResult.TypeCheckSuccess
+          case subquery : PlanExpression[_] if !subquery.containsPattern(OUTER_REFERENCE) =>
+            TypeCheckResult.TypeCheckSuccess
+          case _ =>
+            TypeCheckResult.TypeCheckFailure(s"The Bloom filter binary input to $prettyName " +
+              "should be either a constant value or a scalar subquery expression")
+        }
+      case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
+        s"been ${BinaryType.simpleString} followed by a value with ${LongType.simpleString}, " +
+        s"but it's [${left.dataType.catalogString}, ${right.dataType.catalogString}].")
+    }
+  }
+
+  override protected def withNewChildrenInternal(
+      newBloomFilterExpression: Expression,
+      newValueExpression: Expression): BloomFilterMightContain =
+    copy(bloomFilterExpression = newBloomFilterExpression,
+      valueExpression = newValueExpression)
+
+  // The bloom filter created from `bloomFilterExpression`.
+  @transient private lazy val bloomFilter = {
+    val bytes = bloomFilterExpression.eval().asInstanceOf[Array[Byte]]
+    if (bytes == null) null else deserialize(bytes)
+  }
+
+  override def eval(input: InternalRow): Any = {
+    if (bloomFilter == null) {
+      null
+    } else {
+      val value = valueExpression.eval(input)
+      if (value == null) null else bloomFilter.mightContainLong(value.asInstanceOf[Long])
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    if (bloomFilter == null) {
+      ev.copy(isNull = TrueLiteral, value = JavaCode.defaultLiteral(dataType))
+    } else {
+      val bf = ctx.addReferenceObj("bloomFilter", bloomFilter, classOf[BloomFilter].getName)
+      val valueEval = valueExpression.genCode(ctx)
+      ev.copy(code = code"""
+      ${valueEval.code}
+      boolean ${ev.isNull} = ${valueEval.isNull};
+      ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+      if (!${ev.isNull}) {
+        ${ev.value} = $bf.mightContainLong((Long)${valueEval.value});
+      }""")
+    }
+  }
+
+  final def deserialize(bytes: Array[Byte]): BloomFilter = {
+    val in = new ByteArrayInputStream(bytes)
+    val bloomFilter = BloomFilter.readFrom(in)
+    in.close()
+    bloomFilter
+  }
+
+}
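
A minimal end-to-end sketch of the two new expressions working together. They are internal and not registered in the SQL function registry by default, so this assumes they have been registered the way the new BloomFilterAggregateQuerySuite in this patch does:

  // Sketch only: assumes bloom_filter_agg and might_contain are registered as builtin
  // functions, as BloomFilterAggregateQuerySuite in this patch does.
  val result = spark.sql(
    """SELECT might_contain(
      |         (SELECT bloom_filter_agg(cast(id AS long)) FROM range(1, 10000)),
      |         cast(5 AS long)) AS positive_membership_test""".stripMargin)
  result.show()  // Expected: true, since 5 was put into the Bloom filter (no false negatives).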
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala
new file mode 100644
index 0000000..c734bca
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import java.io.ByteArrayInputStream
+import java.io.ByteArrayOutputStream
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.trees.TernaryLike
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.util.sketch.BloomFilter
+
+/**
+ * An internal aggregate function that creates a Bloom filter from input values.
+ *
+ * @param child                     Child expression of Long values for creating a Bloom filter.
+ * @param estimatedNumItemsExpression The number of estimated distinct items (optional).
+ * @param numBitsExpression         The number of bits to use (optional).
+ */
+case class BloomFilterAggregate(
+    child: Expression,
+    estimatedNumItemsExpression: Expression,
+    numBitsExpression: Expression,
+    override val mutableAggBufferOffset: Int,
+    override val inputAggBufferOffset: Int)
+  extends TypedImperativeAggregate[BloomFilter] with TernaryLike[Expression] {
+
+  def this(child: Expression, estimatedNumItemsExpression: Expression,
+      numBitsExpression: Expression) = {
+    this(child, estimatedNumItemsExpression, numBitsExpression, 0, 0)
+  }
+
+  def this(child: Expression, estimatedNumItemsExpression: Expression) = {
+    this(child, estimatedNumItemsExpression,
+      // 1 byte per item.
+      Multiply(estimatedNumItemsExpression, Literal(8L)))
+  }
+
+  def this(child: Expression) = {
+    this(child, Literal(SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_EXPECTED_NUM_ITEMS)),
+      Literal(SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_NUM_BITS)))
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    (first.dataType, second.dataType, third.dataType) match {
+      case (_, NullType, _) | (_, _, NullType) =>
+        TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as size arguments")
+      case (LongType, LongType, LongType) =>
+        if (!estimatedNumItemsExpression.foldable) {
+          TypeCheckFailure("The estimated number of items provided must be a constant literal")
+        } else if (estimatedNumItems <= 0L) {
+          TypeCheckFailure("The estimated number of items must be a positive value " +
+            s" (current value = $estimatedNumItems)")
+        } else if (!numBitsExpression.foldable) {
+          TypeCheckFailure("The number of bits provided must be a constant literal")
+        } else if (numBits <= 0L) {
+          TypeCheckFailure("The number of bits must be a positive value " +
+            s" (current value = $numBits)")
+        } else {
+          require(estimatedNumItems <=
+            SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))
+          require(numBits <= SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))
+          TypeCheckSuccess
+        }
+      case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
+        s"been a ${LongType.simpleString} value followed with two ${LongType.simpleString} size " +
+        s"arguments, but it's [${first.dataType.catalogString}, " +
+        s"${second.dataType.catalogString}, ${third.dataType.catalogString}]")
+    }
+  }
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = BinaryType
+
+  override def prettyName: String = "bloom_filter_agg"
+
+  // Mark as lazy so that `estimatedNumItems` is not evaluated during tree transformation.
+  private lazy val estimatedNumItems: Long =
+    Math.min(estimatedNumItemsExpression.eval().asInstanceOf[Number].longValue,
+      SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))
+
+  // Mark as lazy so that `numBits` is not evaluated during tree transformation.
+  private lazy val numBits: Long =
+    Math.min(numBitsExpression.eval().asInstanceOf[Number].longValue,
+      SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))
+
+  override def first: Expression = child
+
+  override def second: Expression = estimatedNumItemsExpression
+
+  override def third: Expression = numBitsExpression
+
+  override protected def withNewChildrenInternal(
+      newChild: Expression,
+      newEstimatedNumItemsExpression: Expression,
+      newNumBitsExpression: Expression): BloomFilterAggregate = {
+    copy(child = newChild, estimatedNumItemsExpression = newEstimatedNumItemsExpression,
+      numBitsExpression = newNumBitsExpression)
+  }
+
+  override def createAggregationBuffer(): BloomFilter = {
+    BloomFilter.create(estimatedNumItems, numBits)
+  }
+
+  override def update(buffer: BloomFilter, inputRow: InternalRow): BloomFilter = {
+    val value = child.eval(inputRow)
+    // Ignore null values.
+    if (value == null) {
+      return buffer
+    }
+    buffer.putLong(value.asInstanceOf[Long])
+    buffer
+  }
+
+  override def merge(buffer: BloomFilter, other: BloomFilter): BloomFilter = {
+    buffer.mergeInPlace(other)
+  }
+
+  override def eval(buffer: BloomFilter): Any = {
+    if (buffer.cardinality() == 0) {
+      // There are no set bits in the Bloom filter, hence no non-null value has been processed.
+      return null
+    }
+    serialize(buffer)
+  }
+
+  override def withNewMutableAggBufferOffset(newOffset: Int): BloomFilterAggregate =
+    copy(mutableAggBufferOffset = newOffset)
+
+  override def withNewInputAggBufferOffset(newOffset: Int): BloomFilterAggregate =
+    copy(inputAggBufferOffset = newOffset)
+
+  override def serialize(obj: BloomFilter): Array[Byte] = {
+    BloomFilterAggregate.serialize(obj)
+  }
+
+  override def deserialize(bytes: Array[Byte]): BloomFilter = {
+    BloomFilterAggregate.deserialize(bytes)
+  }
+}
+
+object BloomFilterAggregate {
+  final def serialize(obj: BloomFilter): Array[Byte] = {
+    // BloomFilterImpl.writeTo() writes 2 integers (version number and num hash functions), hence
+    // the +8
+    val size = (obj.bitSize() / 8) + 8
+    require(size <= Integer.MAX_VALUE, s"actual number of bits is too large $size")
+    val out = new ByteArrayOutputStream(size.intValue())
+    obj.writeTo(out)
+    out.close()
+    out.toByteArray
+  }
+
+  final def deserialize(bytes: Array[Byte]): BloomFilter = {
+    val in = new ByteArrayInputStream(bytes)
+    val bloomFilter = BloomFilter.readFrom(in)
+    in.close()
+    bloomFilter
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 6974ada..2c879be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -360,6 +360,8 @@ case class Invoke(
 
   lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(INVOKE)
+
   override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
   override def children: Seq[Expression] = targetObject +: arguments
   override lazy val deterministic: Boolean = isDeterministic && arguments.forall(_.deterministic)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index a2fd668..d16e09c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -287,6 +287,22 @@ trait PredicateHelper extends AliasHelper with Logging {
       }
     }
   }
+
+  /**
+   * Returns whether an expression is likely to be selective
+   */
+  def isLikelySelective(e: Expression): Boolean = e match {
+    case Not(expr) => isLikelySelective(expr)
+    case And(l, r) => isLikelySelective(l) || isLikelySelective(r)
+    case Or(l, r) => isLikelySelective(l) && isLikelySelective(r)
+    case _: StringRegexExpression => true
+    case _: BinaryComparison => true
+    case _: In | _: InSet => true
+    case _: StringPredicate => true
+    case BinaryPredicate(_) => true
+    case _: MultiLikeBase => true
+    case _ => false
+  }
 }
 
 @ExpressionDescription(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 368cbfd..bfaaba5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure,
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.BinaryLike
-import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern}
 import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
@@ -627,6 +627,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
   @transient private var lastReplacementInUTF8: UTF8String = _
   // result buffer write by Matcher
   @transient private lazy val result: StringBuffer = new StringBuffer
+  final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_REPLACE)
 
   override def nullSafeEval(s: Any, p: Any, r: Any, i: Any): Any = {
     if (!p.equals(lastRegex)) {
@@ -751,6 +752,8 @@ abstract class RegExpExtractBase
   // last regex pattern, we cache it for performance concern
   @transient private var pattern: Pattern = _
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_EXTRACT_FAMILY)
+
   override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)
   override def first: Expression = subject
   override def second: Expression = regexp
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
new file mode 100644
index 0000000..35d0189
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
@@ -0,0 +1,303 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate, Complete}
+import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, PhysicalOperation}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, PYTHON_UDF, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, SCALA_UDF}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+/**
+ * Insert a filter on one side of the join if the other side has a selective predicate.
+ * The filter could be an IN subquery (converted to a semi join), a bloom filter, or something
+ * else in the future.
+ */
+object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with JoinSelectionHelper {
+
+  // Wraps `expr` with a hash function if its byte size is larger than an integer.
+  private def mayWrapWithHash(expr: Expression): Expression = {
+    if (expr.dataType.defaultSize > IntegerType.defaultSize) {
+      new Murmur3Hash(Seq(expr))
+    } else {
+      expr
+    }
+  }
+
+  private def injectFilter(
+      filterApplicationSideExp: Expression,
+      filterApplicationSidePlan: LogicalPlan,
+      filterCreationSideExp: Expression,
+      filterCreationSidePlan: LogicalPlan): LogicalPlan = {
+    require(conf.runtimeFilterBloomFilterEnabled || conf.runtimeFilterSemiJoinReductionEnabled)
+    if (conf.runtimeFilterBloomFilterEnabled) {
+      injectBloomFilter(
+        filterApplicationSideExp,
+        filterApplicationSidePlan,
+        filterCreationSideExp,
+        filterCreationSidePlan
+      )
+    } else {
+      injectInSubqueryFilter(
+        filterApplicationSideExp,
+        filterApplicationSidePlan,
+        filterCreationSideExp,
+        filterCreationSidePlan
+      )
+    }
+  }
+
+  private def injectBloomFilter(
+      filterApplicationSideExp: Expression,
+      filterApplicationSidePlan: LogicalPlan,
+      filterCreationSideExp: Expression,
+      filterCreationSidePlan: LogicalPlan): LogicalPlan = {
+    // Skip if the filter creation side is too big
+    if (filterCreationSidePlan.stats.sizeInBytes > conf.runtimeFilterCreationSideThreshold) {
+      return filterApplicationSidePlan
+    }
+    val rowCount = filterCreationSidePlan.stats.rowCount
+    val bloomFilterAgg =
+      if (rowCount.isDefined && rowCount.get.longValue > 0L) {
+        new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)),
+          Literal(rowCount.get.longValue))
+      } else {
+        new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)))
+      }
+    val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None)
+    val alias = Alias(aggExp, "bloomFilter")()
+    val aggregate = ConstantFolding(Aggregate(Nil, Seq(alias), filterCreationSidePlan))
+    val bloomFilterSubquery = ScalarSubquery(aggregate, Nil)
+    val filter = BloomFilterMightContain(bloomFilterSubquery,
+      new XxHash64(Seq(filterApplicationSideExp)))
+    Filter(filter, filterApplicationSidePlan)
+  }
+
+  private def injectInSubqueryFilter(
+      filterApplicationSideExp: Expression,
+      filterApplicationSidePlan: LogicalPlan,
+      filterCreationSideExp: Expression,
+      filterCreationSidePlan: LogicalPlan): LogicalPlan = {
+    require(filterApplicationSideExp.dataType == filterCreationSideExp.dataType)
+    val actualFilterKeyExpr = mayWrapWithHash(filterCreationSideExp)
+    val alias = Alias(actualFilterKeyExpr, actualFilterKeyExpr.toString)()
+    val aggregate = Aggregate(Seq(alias), Seq(alias), filterCreationSidePlan)
+    if (!canBroadcastBySize(aggregate, conf)) {
+      // Skip the InSubquery filter if the size of `aggregate` is beyond broadcast join threshold,
+      // i.e., the semi-join will be a shuffled join, which is not worthwhile.
+      return filterApplicationSidePlan
+    }
+    val filter = InSubquery(Seq(mayWrapWithHash(filterApplicationSideExp)),
+      ListQuery(aggregate, childOutputs = aggregate.output))
+    Filter(filter, filterApplicationSidePlan)
+  }
+
+  /**
+   * Returns whether the plan is a simple filter over scan and the filter is likely selective
+   * Also check if the plan only has simple expressions (attribute reference, literals) so that we
+   * do not add a subquery that might have an expensive computation
+   */
+  private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = {
+    val ret = plan match {
+      case PhysicalOperation(_, filters, child) if child.isInstanceOf[LeafNode] =>
+        filters.forall(isSimpleExpression) &&
+          filters.exists(isLikelySelective)
+      case _ => false
+    }
+    !plan.isStreaming && ret
+  }
+
+  private def isSimpleExpression(e: Expression): Boolean = {
+    !e.containsAnyPattern(PYTHON_UDF, SCALA_UDF, INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY,
+      REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE)
+  }
+
+  private def canFilterLeft(joinType: JoinType): Boolean = joinType match {
+    case Inner | RightOuter => true
+    case _ => false
+  }
+
+  private def canFilterRight(joinType: JoinType): Boolean = joinType match {
+    case Inner | LeftOuter => true
+    case _ => false
+  }
+
+  private def isProbablyShuffleJoin(left: LogicalPlan,
+      right: LogicalPlan, hint: JoinHint): Boolean = {
+    !hintToBroadcastLeft(hint) && !hintToBroadcastRight(hint) &&
+      !canBroadcastBySize(left, conf) && !canBroadcastBySize(right, conf)
+  }
+
+  private def probablyHasShuffle(plan: LogicalPlan): Boolean = {
+    plan.collectFirst {
+      case j@Join(left, right, _, _, hint)
+        if isProbablyShuffleJoin(left, right, hint) => j
+      case a: Aggregate => a
+    }.nonEmpty
+  }
+
+  // Returns the max scan byte size in the subtree rooted at `filterApplicationSide`.
+  private def maxScanByteSize(filterApplicationSide: LogicalPlan): BigInt = {
+    val defaultSizeInBytes = conf.getConf(SQLConf.DEFAULT_SIZE_IN_BYTES)
+    filterApplicationSide.collect({
+      case leaf: LeafNode => leaf
+    }).map(scan => {
+      // DEFAULT_SIZE_IN_BYTES means there's no byte size information in stats. Since we avoid
+      // creating a Bloom filter when the filter application side is very small, using 0 as the
+      // byte size when the actual size is unknown can avoid a regression from applying a Bloom
+      // filter to a small table.
+      if (scan.stats.sizeInBytes == defaultSizeInBytes) BigInt(0) else scan.stats.sizeInBytes
+    }).max
+  }
+
+  // Returns true if `filterApplicationSide` satisfies the byte size requirement to apply a
+  // Bloom filter; false otherwise.
+  private def satisfyByteSizeRequirement(filterApplicationSide: LogicalPlan): Boolean = {
+    // In case `filterApplicationSide` is a union of many small tables, disseminating the Bloom
+    // filter to each small task might be more costly than scanning the tables themselves. Thus,
+    // we use max rather than sum here.
+    val maxScanSize = maxScanByteSize(filterApplicationSide)
+    maxScanSize >=
+      conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD)
+  }
+
+  /**
+   * Check that:
+   * - The filterApplicationSideJoinExp can be pushed down through joins and aggregates (ie the
+   *   expression references originate from a single leaf node)
+   * - The filter creation side has a selective predicate
+   * - The current join is a shuffle join or a broadcast join that has a shuffle below it
+   * - The max filterApplicationSide scan size is greater than a configurable threshold
+   */
+  private def filteringHasBenefit(
+      filterApplicationSide: LogicalPlan,
+      filterCreationSide: LogicalPlan,
+      filterApplicationSideExp: Expression,
+      hint: JoinHint): Boolean = {
+    findExpressionAndTrackLineageDown(filterApplicationSideExp,
+      filterApplicationSide).isDefined && isSelectiveFilterOverScan(filterCreationSide) &&
+      (isProbablyShuffleJoin(filterApplicationSide, filterCreationSide, hint) ||
+        probablyHasShuffle(filterApplicationSide)) &&
+      satisfyByteSizeRequirement(filterApplicationSide)
+  }
+
+  def hasRuntimeFilter(left: LogicalPlan, right: LogicalPlan, leftKey: Expression,
+      rightKey: Expression): Boolean = {
+    if (conf.runtimeFilterBloomFilterEnabled) {
+      hasBloomFilter(left, right, leftKey, rightKey)
+    } else {
+      hasInSubquery(left, right, leftKey, rightKey)
+    }
+  }
+
+  // This checks if there is already a DPP filter, as this rule is called just after DPP.
+  def hasDynamicPruningSubquery(
+      left: LogicalPlan,
+      right: LogicalPlan,
+      leftKey: Expression,
+      rightKey: Expression): Boolean = {
+    (left, right) match {
+      case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan), _) =>
+        pruningKey.fastEquals(leftKey) || hasDynamicPruningSubquery(plan, right, leftKey, rightKey)
+      case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan)) =>
+        pruningKey.fastEquals(rightKey) ||
+          hasDynamicPruningSubquery(left, plan, leftKey, rightKey)
+      case _ => false
+    }
+  }
+
+  def hasBloomFilter(
+      left: LogicalPlan,
+      right: LogicalPlan,
+      leftKey: Expression,
+      rightKey: Expression): Boolean = {
+    findBloomFilterWithExp(left, leftKey) || findBloomFilterWithExp(right, rightKey)
+  }
+
+  private def findBloomFilterWithExp(plan: LogicalPlan, key: Expression): Boolean = {
+    plan.find {
+      case Filter(condition, _) =>
+        splitConjunctivePredicates(condition).exists {
+          case BloomFilterMightContain(_, XxHash64(Seq(valueExpression), _))
+            if valueExpression.fastEquals(key) => true
+          case _ => false
+        }
+      case _ => false
+    }.isDefined
+  }
+
+  def hasInSubquery(left: LogicalPlan, right: LogicalPlan, leftKey: Expression,
+      rightKey: Expression): Boolean = {
+    (left, right) match {
+      case (Filter(InSubquery(Seq(key),
+      ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _), _) =>
+        key.fastEquals(leftKey) || key.fastEquals(new Murmur3Hash(Seq(leftKey)))
+      case (_, Filter(InSubquery(Seq(key),
+      ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _)) =>
+        key.fastEquals(rightKey) || key.fastEquals(new Murmur3Hash(Seq(rightKey)))
+      case _ => false
+    }
+  }
+
+  private def tryInjectRuntimeFilter(plan: LogicalPlan): LogicalPlan = {
+    var filterCounter = 0
+    val numFilterThreshold = conf.getConf(SQLConf.RUNTIME_FILTER_NUMBER_THRESHOLD)
+    plan transformUp {
+      case join @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, left, right, hint) =>
+        var newLeft = left
+        var newRight = right
+        (leftKeys, rightKeys).zipped.foreach((l, r) => {
+          // Check if:
+          // 1. There is already a DPP filter on the key
+          // 2. There is already a runtime filter (Bloom filter or IN subquery) on the key
+          // 3. The keys are simple cheap expressions
+          if (filterCounter < numFilterThreshold &&
+            !hasDynamicPruningSubquery(left, right, l, r) &&
+            !hasRuntimeFilter(newLeft, newRight, l, r) &&
+            isSimpleExpression(l) && isSimpleExpression(r)) {
+            val oldLeft = newLeft
+            val oldRight = newRight
+            if (canFilterLeft(joinType) && filteringHasBenefit(left, right, l, hint)) {
+              newLeft = injectFilter(l, newLeft, r, right)
+            }
+            // Did we actually inject on the left? If not, try on the right
+            if (newLeft.fastEquals(oldLeft) && canFilterRight(joinType) &&
+              filteringHasBenefit(right, left, r, hint)) {
+              newRight = injectFilter(r, newRight, l, left)
+            }
+            if (!newLeft.fastEquals(oldLeft) || !newRight.fastEquals(oldRight)) {
+              filterCounter = filterCounter + 1
+            }
+          }
+        })
+        join.withNewChildren(Seq(newLeft, newRight))
+    }
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan match {
+    case s: Subquery if s.correlated => plan
+    case _ if !conf.runtimeFilterSemiJoinReductionEnabled &&
+      !conf.runtimeFilterBloomFilterEnabled => plan
+    case _ => tryInjectRuntimeFilter(plan)
+  }
+
+}
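
The effect of this rule can be observed on the optimized logical plan. A rough sketch, assuming the `bf1`/`bf2` tables created by the new InjectRuntimeFilterSuite below, the Bloom filter config enabled, and the size thresholds relaxed so that the small test tables qualify:

  import org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain
  import org.apache.spark.sql.catalyst.plans.logical.Filter

  // Assumptions for this toy example: force a shuffle join and drop the 10GB
  // application-side scan size threshold so the tiny test tables qualify.
  spark.conf.set("spark.sql.optimizer.runtime.bloomFilter.enabled", "true")
  spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
  spark.conf.set("spark.sql.optimizer.runtime.bloomFilter.applicationSideScanSizethreshold", "0")

  val df = spark.sql(
    "SELECT * FROM bf1 JOIN bf2 ON bf1.c1 = bf2.c2 WHERE bf2.a2 = 62")
  // Look for an injected Filter(BloomFilterMightContain(...), ...) on the bf1 side.
  val injected = df.queryExecution.optimizedPlan.find {
    case Filter(condition, _) =>
      condition.find(_.isInstanceOf[BloomFilterMightContain]).isDefined
    case _ => false
  }
  println(injected.isDefined)  // Expected: true when the rule fires.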
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index b595966..3cf45d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -54,6 +54,7 @@ object TreePattern extends Enumeration  {
   val IN_SUBQUERY: Value = Value
   val INSET: Value = Value
   val INTERSECT: Value = Value
+  val INVOKE: Value = Value
   val JSON_TO_STRUCT: Value = Value
   val LAMBDA_FUNCTION: Value = Value
   val LAMBDA_VARIABLE: Value = Value
@@ -72,6 +73,8 @@ object TreePattern extends Enumeration  {
   val PIVOT: Value = Value
   val PLAN_EXPRESSION: Value = Value
   val PYTHON_UDF: Value = Value
+  val REGEXP_EXTRACT_FAMILY: Value = Value
+  val REGEXP_REPLACE: Value = Value
   val RUNTIME_REPLACEABLE: Value = Value
   val SCALAR_SUBQUERY: Value = Value
   val SCALA_UDF: Value = Value
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 3314dd1..1bba8b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -341,6 +341,77 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED =
+    buildConf("spark.sql.optimizer.runtimeFilter.semiJoinReduction.enabled")
+      .doc("When true and if one side of a shuffle join has a selective predicate, we attempt " +
+        "to insert a semi join in the other side to reduce the amount of shuffle data.")
+      .version("3.3.0")
+      .booleanConf
+      .createWithDefault(false)
+
+  val RUNTIME_FILTER_NUMBER_THRESHOLD =
+    buildConf("spark.sql.optimizer.runtimeFilter.number.threshold")
+      .doc("The total number of injected runtime filters (non-DPP) for a single " +
+        "query. This is to prevent driver OOMs with too many Bloom filters.")
+      .version("3.3.0")
+      .intConf
+      .checkValue(threshold => threshold >= 0, "The threshold should be >= 0")
+      .createWithDefault(10)
+
+  val RUNTIME_BLOOM_FILTER_ENABLED =
+    buildConf("spark.sql.optimizer.runtime.bloomFilter.enabled")
+      .doc("When true and if one side of a shuffle join has a selective predicate, we attempt " +
+        "to insert a bloom filter in the other side to reduce the amount of shuffle data.")
+      .version("3.3.0")
+      .booleanConf
+      .createWithDefault(false)
+
+  val RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD =
+    buildConf("spark.sql.optimizer.runtime.bloomFilter.creationSideThreshold")
+      .doc("Size threshold of the bloom filter creation side plan. Estimated size needs to be " +
+        "under this value to try to inject bloom filter.")
+      .version("3.3.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("10MB")
+
+  val RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD =
+    buildConf("spark.sql.optimizer.runtime.bloomFilter.applicationSideScanSizethreshold")
+      .doc("Byte size threshold of the Bloom filter application side plan's aggregated scan " +
+        "size. Aggregated scan byte size of the Bloom filter application side needs to be over " +
+        "this value to inject a bloom filter.")
+      .version("3.3.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("10GB")
+
+  val RUNTIME_BLOOM_FILTER_EXPECTED_NUM_ITEMS =
+    buildConf("spark.sql.optimizer.runtime.bloomFilter.expectedNumItems")
+      .doc("The default number of expected items for the runtime bloomfilter")
+      .version("3.3.0")
+      .longConf
+      .createWithDefault(1000000L)
+
+  val RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS =
+    buildConf("spark.sql.optimizer.runtime.bloomFilter.maxNumItems")
+      .doc("The max allowed number of expected items for the runtime bloom filter")
+      .version("3.3.0")
+      .longConf
+      .createWithDefault(4000000L)
+
+
+  val RUNTIME_BLOOM_FILTER_NUM_BITS =
+    buildConf("spark.sql.optimizer.runtime.bloomFilter.numBits")
+      .doc("The default number of bits to use for the runtime bloom filter")
+      .version("3.3.0")
+      .longConf
+      .createWithDefault(8388608L)
+
+  val RUNTIME_BLOOM_FILTER_MAX_NUM_BITS =
+    buildConf("spark.sql.optimizer.runtime.bloomFilter.maxNumBits")
+      .doc("The max number of bits to use for the runtime bloom filter")
+      .version("3.3.0")
+      .longConf
+      .createWithDefault(67108864L)
+
   val COMPRESS_CACHED = buildConf("spark.sql.inMemoryColumnarStorage.compressed")
     .doc("When set to true Spark SQL will automatically select a compression codec for each " +
       "column based on statistics of the data.")
@@ -3750,6 +3821,15 @@ class SQLConf extends Serializable with Logging {
   def dynamicPartitionPruningReuseBroadcastOnly: Boolean =
     getConf(DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY)
 
+  def runtimeFilterSemiJoinReductionEnabled: Boolean =
+    getConf(RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED)
+
+  def runtimeFilterBloomFilterEnabled: Boolean =
+    getConf(RUNTIME_BLOOM_FILTER_ENABLED)
+
+  def runtimeFilterCreationSideThreshold: Long =
+    getConf(RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD)
+
   def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS)
 
   def isStateSchemaCheckEnabled: Boolean = getConf(STATE_SCHEMA_CHECK_ENABLED)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 7e8fb4a..743cb59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -43,6 +43,8 @@ class SparkOptimizer(
     Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
     Batch("PartitionPruning", Once,
       PartitionPruning) :+
+    Batch("InjectRuntimeFilter", FixedPoint(1),
+      InjectRuntimeFilter) :+
     Batch("Pushdown Filters from PartitionPruning", fixedPoint,
       PushDownPredicates) :+
     Batch("Cleanup filters that cannot be pushed down", Once,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
index 3b5fc4a..89d6603 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
@@ -194,21 +194,6 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper {
     scanOverhead + cachedOverhead
   }
 
-  /**
-   * Returns whether an expression is likely to be selective
-   */
-  private def isLikelySelective(e: Expression): Boolean = e match {
-    case Not(expr) => isLikelySelective(expr)
-    case And(l, r) => isLikelySelective(l) || isLikelySelective(r)
-    case Or(l, r) => isLikelySelective(l) && isLikelySelective(r)
-    case _: StringRegexExpression => true
-    case _: BinaryComparison => true
-    case _: In | _: InSet => true
-    case _: StringPredicate => true
-    case BinaryPredicate(_) => true
-    case _: MultiLikeBase => true
-    case _ => false
-  }
 
   /**
    * Search a filtering predicate in a given logical plan
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala
new file mode 100644
index 0000000..025593b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+/**
+ * Query tests for the Bloom filter aggregate and filter function.
+ */
+class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
+  val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
+  val funcId_might_contain = new FunctionIdentifier("might_contain")
+
+  // Register 'bloom_filter_agg' to builtin.
+  FunctionRegistry.builtin.registerFunction(funcId_bloom_filter_agg,
+    new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
+    (children: Seq[Expression]) => children.size match {
+      case 1 => new BloomFilterAggregate(children.head)
+      case 2 => new BloomFilterAggregate(children.head, children(1))
+      case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
+    })
+
+  // Register 'might_contain' to builtin.
+  FunctionRegistry.builtin.registerFunction(funcId_might_contain,
+    new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"),
+    (children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1)))
+
+  override def afterAll(): Unit = {
+    FunctionRegistry.builtin.dropFunction(funcId_bloom_filter_agg)
+    FunctionRegistry.builtin.dropFunction(funcId_might_contain)
+    super.afterAll()
+  }
+
+  test("Test bloom_filter_agg and might_contain") {
+    val conf = SQLConf.get
+    val table = "bloom_filter_test"
+    for (numEstimatedItems <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue,
+      conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))) {
+      for (numBits <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue,
+        conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))) {
+        val sqlString = s"""
+                           |SELECT every(might_contain(
+                           |            (SELECT bloom_filter_agg(col,
+                           |              cast($numEstimatedItems as long),
+                           |              cast($numBits as long))
+                           |             FROM $table),
+                           |            col)) positive_membership_test,
+                           |       every(might_contain(
+                           |            (SELECT bloom_filter_agg(col,
+                           |              cast($numEstimatedItems as long),
+                           |              cast($numBits as long))
+                           |             FROM values (-1L), (100001L), (20000L) as t(col)),
+                           |            col)) negative_membership_test
+                           |FROM $table
+           """.stripMargin
+        withTempView(table) {
+          (Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 10000L))
+            .toDF("col").createOrReplaceTempView(table)
+          // Validate error messages as well as answers when there's no error.
+          if (numEstimatedItems <= 0) {
+            val exception = intercept[AnalysisException] {
+              spark.sql(sqlString)
+            }
+            assert(exception.getMessage.contains(
+              "The estimated number of items must be a positive value"))
+          } else if (numBits <= 0) {
+            val exception = intercept[AnalysisException] {
+              spark.sql(sqlString)
+            }
+            assert(exception.getMessage.contains("The number of bits must be a positive value"))
+          } else {
+            checkAnswer(spark.sql(sqlString), Row(true, false))
+          }
+        }
+      }
+    }
+  }
+
+  test("Test that bloom_filter_agg errors out disallowed input value types") {
+    val exception1 = intercept[AnalysisException] {
+      spark.sql("""
+        |SELECT bloom_filter_agg(a)
+        |FROM values (1.2), (2.5) as t(a)"""
+        .stripMargin)
+    }
+    assert(exception1.getMessage.contains(
+      "Input to function bloom_filter_agg should have been a bigint value"))
+
+    val exception2 = intercept[AnalysisException] {
+      spark.sql("""
+        |SELECT bloom_filter_agg(a, 2)
+        |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
+        .stripMargin)
+    }
+    assert(exception2.getMessage.contains(
+      "function bloom_filter_agg should have been a bigint value followed with two bigint"))
+
+    val exception3 = intercept[AnalysisException] {
+      spark.sql("""
+        |SELECT bloom_filter_agg(a, cast(2 as long), 5)
+        |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
+        .stripMargin)
+    }
+    assert(exception3.getMessage.contains(
+      "function bloom_filter_agg should have been a bigint value followed with two bigint"))
+
+    val exception4 = intercept[AnalysisException] {
+      spark.sql("""
+        |SELECT bloom_filter_agg(a, null, 5)
+        |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
+        .stripMargin)
+    }
+    assert(exception4.getMessage.contains("Null typed values cannot be used as size arguments"))
+
+    val exception5 = intercept[AnalysisException] {
+      spark.sql("""
+        |SELECT bloom_filter_agg(a, 5, null)
+        |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
+        .stripMargin)
+    }
+    assert(exception5.getMessage.contains("Null typed values cannot be used as size arguments"))
+  }
+
+  test("Test that might_contain errors out disallowed input value types") {
+    val exception1 = intercept[AnalysisException] {
+      spark.sql("""|SELECT might_contain(1.0, 1L)"""
+        .stripMargin)
+    }
+    assert(exception1.getMessage.contains(
+      "Input to function might_contain should have been binary followed by a value with bigint"))
+
+    val exception2 = intercept[AnalysisException] {
+      spark.sql("""|SELECT might_contain(NULL, 0.1)"""
+        .stripMargin)
+    }
+    assert(exception2.getMessage.contains(
+      "Input to function might_contain should have been binary followed by a value with bigint"))
+  }
+
+  test("Test that might_contain errors out non-constant Bloom filter") {
+    val exception1 = intercept[AnalysisException] {
+      spark.sql("""
+                  |SELECT might_contain(cast(a as binary), cast(5 as long))
+                  |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
+        .stripMargin)
+    }
+    assert(exception1.getMessage.contains(
+      "The Bloom filter binary input to might_contain should be either a constant value or " +
+        "a scalar subquery expression"))
+
+    val exception2 = intercept[AnalysisException] {
+      spark.sql("""
+                  |SELECT might_contain((select cast(a as binary)), cast(5 as long))
+                  |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
+        .stripMargin)
+    }
+    assert(exception2.getMessage.contains(
+      "The Bloom filter binary input to might_contain should be either a constant value or " +
+        "a scalar subquery expression"))
+  }
+
+  test("Test that might_contain can take a constant value input") {
+    checkAnswer(spark.sql(
+      """SELECT might_contain(
+        |X'00000001000000050000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267',
+        |cast(201 as long))""".stripMargin),
+      Row(false))
+  }
+
+  test("Test that bloom_filter_agg produces a NULL with empty input") {
+    checkAnswer(spark.sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1)"""),
+      Row(null))
+  }
+
+  test("Test NULL inputs for might_contain") {
+    checkAnswer(spark.sql(
+      s"""
+         |SELECT might_contain(null, null) both_null,
+         |       might_contain(null, 1L) null_bf,
+         |       might_contain((SELECT bloom_filter_agg(cast(id as long)) from range(1, 10000)),
+         |            null) null_value
+         """.stripMargin),
+      Row(null, null, null))
+  }
+
+  test("Test that a query with bloom_filter_agg has partial aggregates") {
+    assert(spark.sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1000000)""")
+      .queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].inputPlan
+      .collect({case agg: BaseAggregateExec => agg}).size == 2)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
new file mode 100644
index 0000000..a5e27fb
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
@@ -0,0 +1,503 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, BloomFilterMightContain, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
+import org.apache.spark.sql.types.{IntegerType, StructType}
+
+class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSparkSession {
+
+  protected override def beforeAll(): Unit = {
+    super.beforeAll()
+    val schema = new StructType().add("a1", IntegerType, nullable = true)
+      .add("b1", IntegerType, nullable = true)
+      .add("c1", IntegerType, nullable = true)
+      .add("d1", IntegerType, nullable = true)
+      .add("e1", IntegerType, nullable = true)
+      .add("f1", IntegerType, nullable = true)
+
+    val data1 = Seq(Seq(null, 47, null, 4, 6, 48),
+      Seq(73, 63, null, 92, null, null),
+      Seq(76, 10, 74, 98, 37, 5),
+      Seq(0, 63, null, null, null, null),
+      Seq(15, 77, null, null, null, null),
+      Seq(null, 57, 33, 55, null, 58),
+      Seq(4, 0, 86, null, 96, 14),
+      Seq(28, 16, 58, null, null, null),
+      Seq(1, 88, null, 8, null, 79),
+      Seq(59, null, null, null, 20, 25),
+      Seq(1, 50, null, 94, 94, null),
+      Seq(null, null, null, 67, 51, 57),
+      Seq(77, 50, 8, 90, 16, 21),
+      Seq(34, 28, null, 5, null, 64),
+      Seq(null, null, 88, 11, 63, 79),
+      Seq(92, 94, 23, 1, null, 64),
+      Seq(57, 56, null, 83, null, null),
+      Seq(null, 35, 8, 35, null, 70),
+      Seq(null, 8, null, 35, null, 87),
+      Seq(9, null, null, 60, null, 5),
+      Seq(null, 15, 66, null, 83, null))
+    val rdd1 = spark.sparkContext.parallelize(data1)
+    val rddRow1 = rdd1.map(s => Row.fromSeq(s))
+    spark.createDataFrame(rddRow1, schema).write.saveAsTable("bf1")
+
+    val schema2 = new StructType().add("a2", IntegerType, nullable = true)
+      .add("b2", IntegerType, nullable = true)
+      .add("c2", IntegerType, nullable = true)
+      .add("d2", IntegerType, nullable = true)
+      .add("e2", IntegerType, nullable = true)
+      .add("f2", IntegerType, nullable = true)
+
+
+    val data2 = Seq(Seq(67, 17, 45, 91, null, null),
+      Seq(98, 63, 0, 89, null, 40),
+      Seq(null, 76, 68, 75, 20, 19),
+      Seq(8, null, null, null, 78, null),
+      Seq(48, 62, null, null, 11, 98),
+      Seq(84, null, 99, 65, 66, 51),
+      Seq(98, null, null, null, 42, 51),
+      Seq(10, 3, 29, null, 68, 8),
+      Seq(85, 36, 41, null, 28, 71),
+      Seq(89, null, 94, 95, 67, 21),
+      Seq(44, null, 24, 33, null, 6),
+      Seq(null, 6, 78, 31, null, 69),
+      Seq(59, 2, 63, 9, 66, 20),
+      Seq(5, 23, 10, 86, 68, null),
+      Seq(null, 63, 99, 55, 9, 65),
+      Seq(57, 62, 68, 5, null, 0),
+      Seq(75, null, 15, null, 81, null),
+      Seq(53, null, 6, 68, 28, 13),
+      Seq(null, null, null, null, 89, 23),
+      Seq(36, 73, 40, null, 8, null),
+      Seq(24, null, null, 40, null, null))
+    val rdd2 = spark.sparkContext.parallelize(data2)
+    val rddRow2 = rdd2.map(s => Row.fromSeq(s))
+    spark.createDataFrame(rddRow2, schema2).write.saveAsTable("bf2")
+
+    val schema3 = new StructType().add("a3", IntegerType, nullable = true)
+      .add("b3", IntegerType, nullable = true)
+      .add("c3", IntegerType, nullable = true)
+      .add("d3", IntegerType, nullable = true)
+      .add("e3", IntegerType, nullable = true)
+      .add("f3", IntegerType, nullable = true)
+
+    val data3 = Seq(Seq(67, 17, 45, 91, null, null),
+      Seq(98, 63, 0, 89, null, 40),
+      Seq(null, 76, 68, 75, 20, 19),
+      Seq(8, null, null, null, 78, null),
+      Seq(48, 62, null, null, 11, 98),
+      Seq(84, null, 99, 65, 66, 51),
+      Seq(98, null, null, null, 42, 51),
+      Seq(10, 3, 29, null, 68, 8),
+      Seq(85, 36, 41, null, 28, 71),
+      Seq(89, null, 94, 95, 67, 21),
+      Seq(44, null, 24, 33, null, 6),
+      Seq(null, 6, 78, 31, null, 69),
+      Seq(59, 2, 63, 9, 66, 20),
+      Seq(5, 23, 10, 86, 68, null),
+      Seq(null, 63, 99, 55, 9, 65),
+      Seq(57, 62, 68, 5, null, 0),
+      Seq(75, null, 15, null, 81, null),
+      Seq(53, null, 6, 68, 28, 13),
+      Seq(null, null, null, null, 89, 23),
+      Seq(36, 73, 40, null, 8, null),
+      Seq(24, null, null, 40, null, null))
+    val rdd3 = spark.sparkContext.parallelize(data3)
+    val rddRow3 = rdd3.map(s => Row.fromSeq(s))
+    spark.createDataFrame(rddRow3, schema3).write.saveAsTable("bf3")
+
+    val schema4 = new StructType().add("a4", IntegerType, nullable = true)
+      .add("b4", IntegerType, nullable = true)
+      .add("c4", IntegerType, nullable = true)
+      .add("d4", IntegerType, nullable = true)
+      .add("e4", IntegerType, nullable = true)
+      .add("f4", IntegerType, nullable = true)
+
+    val data4 = Seq(Seq(67, 17, 45, 91, null, null),
+      Seq(98, 63, 0, 89, null, 40),
+      Seq(null, 76, 68, 75, 20, 19),
+      Seq(8, null, null, null, 78, null),
+      Seq(48, 62, null, null, 11, 98),
+      Seq(84, null, 99, 65, 66, 51),
+      Seq(98, null, null, null, 42, 51),
+      Seq(10, 3, 29, null, 68, 8),
+      Seq(85, 36, 41, null, 28, 71),
+      Seq(89, null, 94, 95, 67, 21),
+      Seq(44, null, 24, 33, null, 6),
+      Seq(null, 6, 78, 31, null, 69),
+      Seq(59, 2, 63, 9, 66, 20),
+      Seq(5, 23, 10, 86, 68, null),
+      Seq(null, 63, 99, 55, 9, 65),
+      Seq(57, 62, 68, 5, null, 0),
+      Seq(75, null, 15, null, 81, null),
+      Seq(53, null, 6, 68, 28, 13),
+      Seq(null, null, null, null, 89, 23),
+      Seq(36, 73, 40, null, 8, null),
+      Seq(24, null, null, 40, null, null))
+    val rdd4 = spark.sparkContext.parallelize(data4)
+    val rddRow4 = rdd4.map(s => Row.fromSeq(s))
+    spark.createDataFrame(rddRow4, schema4).write.saveAsTable("bf4")
+
+    val schema5part = new StructType().add("a5", IntegerType, nullable = true)
+      .add("b5", IntegerType, nullable = true)
+      .add("c5", IntegerType, nullable = true)
+      .add("d5", IntegerType, nullable = true)
+      .add("e5", IntegerType, nullable = true)
+      .add("f5", IntegerType, nullable = true)
+
+    val data5part = Seq(Seq(67, 17, 45, 91, null, null),
+      Seq(98, 63, 0, 89, null, 40),
+      Seq(null, 76, 68, 75, 20, 19),
+      Seq(8, null, null, null, 78, null),
+      Seq(48, 62, null, null, 11, 98),
+      Seq(84, null, 99, 65, 66, 51),
+      Seq(98, null, null, null, 42, 51),
+      Seq(10, 3, 29, null, 68, 8),
+      Seq(85, 36, 41, null, 28, 71),
+      Seq(89, null, 94, 95, 67, 21),
+      Seq(44, null, 24, 33, null, 6),
+      Seq(null, 6, 78, 31, null, 69),
+      Seq(59, 2, 63, 9, 66, 20),
+      Seq(5, 23, 10, 86, 68, null),
+      Seq(null, 63, 99, 55, 9, 65),
+      Seq(57, 62, 68, 5, null, 0),
+      Seq(75, null, 15, null, 81, null),
+      Seq(53, null, 6, 68, 28, 13),
+      Seq(null, null, null, null, 89, 23),
+      Seq(36, 73, 40, null, 8, null),
+      Seq(24, null, null, 40, null, null))
+    val rdd5part = spark.sparkContext.parallelize(data5part)
+    val rddRow5part = rdd5part.map(s => Row.fromSeq(s))
+    spark.createDataFrame(rddRow5part, schema5part).write.partitionBy("f5")
+      .saveAsTable("bf5part")
+    spark.createDataFrame(rddRow5part, schema5part).filter("a5 > 30")
+      .write.partitionBy("f5")
+      .saveAsTable("bf5filtered")
+
+    sql("analyze table bf1 compute statistics for columns a1, b1, c1, d1, e1, f1")
+    sql("analyze table bf2 compute statistics for columns a2, b2, c2, d2, e2, f2")
+    sql("analyze table bf3 compute statistics for columns a3, b3, c3, d3, e3, f3")
+    sql("analyze table bf4 compute statistics for columns a4, b4, c4, d4, e4, f4")
+    sql("analyze table bf5part compute statistics for columns a5, b5, c5, d5, e5, f5")
+    sql("analyze table bf5filtered compute statistics for columns a5, b5, c5, d5, e5, f5")
+  }
+
+  protected override def afterAll(): Unit = try {
+    sql("DROP TABLE IF EXISTS bf1")
+    sql("DROP TABLE IF EXISTS bf2")
+    sql("DROP TABLE IF EXISTS bf3")
+    sql("DROP TABLE IF EXISTS bf4")
+    sql("DROP TABLE IF EXISTS bf5part")
+    sql("DROP TABLE IF EXISTS bf5filtered")
+  } finally {
+    super.afterAll()
+  }
+
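+  /**
+   * Runs `query` with both runtime filter mechanisms disabled to capture the baseline optimized
+   * plan and answer, then re-runs it with either semi-join reduction or bloom filter injection
+   * enabled (depending on `testSemiJoin`). Checks that the answer is unchanged and that the
+   * optimized plan was or was not rewritten, as indicated by `shouldReplace`.
+   */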
+  def checkWithAndWithoutFeatureEnabled(query: String, testSemiJoin: Boolean,
+      shouldReplace: Boolean): Unit = {
+    var planDisabled: LogicalPlan = null
+    var planEnabled: LogicalPlan = null
+    var expectedAnswer: Array[Row] = null
+
+    withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+      SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") {
+      planDisabled = sql(query).queryExecution.optimizedPlan
+      expectedAnswer = sql(query).collect()
+    }
+
+    if (testSemiJoin) {
+      withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "true",
+        SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") {
+        planEnabled = sql(query).queryExecution.optimizedPlan
+        checkAnswer(sql(query), expectedAnswer)
+      }
+      if (shouldReplace) {
+        val normalizedEnabled = normalizePlan(normalizeExprIds(planEnabled))
+        val normalizedDisabled = normalizePlan(normalizeExprIds(planDisabled))
+        assert(normalizedEnabled != normalizedDisabled)
+      } else {
+        comparePlans(planDisabled, planEnabled)
+      }
+    } else {
+      withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+        SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true") {
+        planEnabled = sql(query).queryExecution.optimizedPlan
+        checkAnswer(sql(query), expectedAnswer)
+        if (shouldReplace) {
+          assert(getNumBloomFilters(planEnabled) > getNumBloomFilters(planDisabled))
+        } else {
+          assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled))
+        }
+      }
+    }
+  }
+
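+  // Counts the BloomFilterMightContain predicates injected into Filter conditions, after
+  // asserting that they are matched by an equal number of BloomFilterAggregate scalar
+  // subqueries whose size parameters are literal values.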
+  def getNumBloomFilters(plan: LogicalPlan): Integer = {
+    val numBloomFilterAggs = plan.collect {
+      case Filter(condition, _) => condition.collect {
+        case subquery: org.apache.spark.sql.catalyst.expressions.ScalarSubquery =>
+          subquery.plan.collect {
+            case Aggregate(_, aggregateExpressions, _) =>
+              aggregateExpressions.map {
+                case Alias(AggregateExpression(bfAgg: BloomFilterAggregate, _, _, _, _), _) =>
+                  assert(bfAgg.estimatedNumItemsExpression.isInstanceOf[Literal])
+                  assert(bfAgg.numBitsExpression.isInstanceOf[Literal])
+                  1
+              }.sum
+          }.sum
+      }.sum
+    }.sum
+    val numMightContains = plan.collect {
+      case Filter(condition, _) => condition.collect {
+        case BloomFilterMightContain(_, _) => 1
+      }.sum
+    }.sum
+    assert(numBloomFilterAggs == numMightContains)
+    numMightContains
+  }
+
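+  // Convenience wrappers over checkWithAndWithoutFeatureEnabled for the two mechanisms.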
+  def assertRewroteSemiJoin(query: String): Unit = {
+    checkWithAndWithoutFeatureEnabled(query, testSemiJoin = true, shouldReplace = true)
+  }
+
+  def assertDidNotRewriteSemiJoin(query: String): Unit = {
+    checkWithAndWithoutFeatureEnabled(query, testSemiJoin = true, shouldReplace = false)
+  }
+
+  def assertRewroteWithBloomFilter(query: String): Unit = {
+    checkWithAndWithoutFeatureEnabled(query, testSemiJoin = false, shouldReplace = true)
+  }
+
+  def assertDidNotRewriteWithBloomFilter(query: String): Unit = {
+    checkWithAndWithoutFeatureEnabled(query, testSemiJoin = false, shouldReplace = false)
+  }
+
+  test("Runtime semi join reduction: simple") {
+    // Filter creation side is 3409 bytes
+    // Filter application side scan is 3362 bytes
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      assertRewroteSemiJoin("select * from bf1 join bf2 on bf1.c1 = bf2.c2 where bf2.a2 = 62")
+      assertDidNotRewriteSemiJoin("select * from bf1 join bf2 on bf1.c1 = bf2.c2")
+    }
+  }
+
+  test("Runtime semi join reduction: two joins") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      assertRewroteSemiJoin("select * from bf1 join bf2 join bf3 on bf1.c1 = bf2.c2 " +
+        "and bf3.c3 = bf2.c2 where bf2.a2 = 5")
+    }
+  }
+
+  test("Runtime semi join reduction: three joins") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      assertRewroteSemiJoin("select * from bf1 join bf2 join bf3 join bf4 on " +
+        "bf1.c1 = bf2.c2 and bf2.c2 = bf3.c3 and bf3.c3 = bf4.c4 where bf1.a1 = 5")
+    }
+  }
+
+  test("Runtime semi join reduction: simple expressions only") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      val squared = (s: Long) => {
+        s * s
+      }
+      spark.udf.register("square", squared)
+      assertDidNotRewriteSemiJoin("select * from bf1 join bf2 on " +
+        "bf1.c1 = bf2.c2 where square(bf2.a2) = 62")
+      assertDidNotRewriteSemiJoin("select * from bf1 join bf2 on " +
+        "bf1.c1 = square(bf2.c2) where bf2.a2= 62")
+    }
+  }
+
+  test("Runtime bloom filter join: simple") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      assertRewroteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " +
+        "where bf2.a2 = 62")
+      assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2")
+    }
+  }
+
+  test("Runtime bloom filter join: two filters single join") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      var planDisabled: LogicalPlan = null
+      var planEnabled: LogicalPlan = null
+      var expectedAnswer: Array[Row] = null
+
+      val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " +
+        "bf1.b1 = bf2.b2 where bf2.a2 = 62"
+
+      withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+        SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") {
+        planDisabled = sql(query).queryExecution.optimizedPlan
+        expectedAnswer = sql(query).collect()
+      }
+
+      withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+        SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true") {
+        planEnabled = sql(query).queryExecution.optimizedPlan
+        checkAnswer(sql(query), expectedAnswer)
+      }
+      assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) + 2)
+    }
+  }
+
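+  // The query below is eligible for two bloom filters; RUNTIME_FILTER_NUMBER_THRESHOLD caps how
+  // many are injected, so the count grows with the threshold until it reaches two.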
+  test("Runtime bloom filter join: test the number of filter threshold") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      var planDisabled: LogicalPlan = null
+      var planEnabled: LogicalPlan = null
+      var expectedAnswer: Array[Row] = null
+
+      val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " +
+        "bf1.b1 = bf2.b2 where bf2.a2 = 62"
+
+      withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+        SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") {
+        planDisabled = sql(query).queryExecution.optimizedPlan
+        expectedAnswer = sql(query).collect()
+      }
+
+      for (numFilterThreshold <- 0 to 3) {
+        withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+          SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true",
+          SQLConf.RUNTIME_FILTER_NUMBER_THRESHOLD.key -> numFilterThreshold.toString) {
+          planEnabled = sql(query).queryExecution.optimizedPlan
+          checkAnswer(sql(query), expectedAnswer)
+        }
+        if (numFilterThreshold < 3) {
+          assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled)
+            + numFilterThreshold)
+        } else {
+          assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) + 2)
+        }
+      }
+    }
+  }
+
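+  // Both join conditions use the same application-side column (bf1.c1), so only one bloom
+  // filter is expected on that column.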
+  test("Runtime bloom filter join: insert one bloom filter per column") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      var planDisabled: LogicalPlan = null
+      var planEnabled: LogicalPlan = null
+      var expectedAnswer: Array[Row] = null
+
+      val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " +
+        "bf1.c1 = bf2.b2 where bf2.a2 = 62"
+
+      withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+        SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") {
+        planDisabled = sql(query).queryExecution.optimizedPlan
+        expectedAnswer = sql(query).collect()
+      }
+
+      withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+        SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true") {
+        planEnabled = sql(query).queryExecution.optimizedPlan
+        checkAnswer(sql(query), expectedAnswer)
+      }
+      assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) + 1)
+    }
+  }
+
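+  // bf5part is partitioned by f5, so the join on f5 is covered by dynamic partition pruning and
+  // no additional bloom filter should be inserted on that column.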
+  test("Runtime bloom filter join: do not add bloom filter if dpp filter exists " +
+    "on the same column") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      assertDidNotRewriteWithBloomFilter("select * from bf5part join bf2 on " +
+        "bf5part.f5 = bf2.c2 where bf2.a2 = 62")
+    }
+  }
+
+  test("Runtime bloom filter join: add bloom filter if dpp filter exists on " +
+    "a different column") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      assertRewroteWithBloomFilter("select * from bf5part join bf2 on " +
+        "bf5part.c5 = bf2.c2 and bf5part.f5 = bf2.f2 where bf2.a2 = 62")
+    }
+  }
+
+  test("Runtime bloom filter join: BF rewrite triggering threshold test") {
+    // Filter creation side data size is 3409 bytes. On the filter application side, an individual
+    // scan's byte size is 3362.
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "3000",
+      SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "4000"
+    ) {
+      assertRewroteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " +
+        "where bf2.a2 = 62")
+    }
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50",
+      SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "50"
+    ) {
+      assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " +
+        "where bf2.a2 = 62")
+    }
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "5000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "3000",
+      SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "4000"
+    ) {
+      // Rewrite should not be triggered as the Bloom filter application side scan size is small.
+      assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 "
+        + "where bf2.a2 = 62")
+    }
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "32",
+      SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "4000") {
+      // Test that the max scan size rather than an individual scan size on the filter
+      // application side matters. `bf5filtered` has 14168 bytes and `bf2` has 3409 bytes.
+      withSQLConf(
+        SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "5000") {
+        assertRewroteWithBloomFilter("select * from " +
+          "(select * from bf5filtered union all select * from bf2) t " +
+          "join bf3 on t.c5 = bf3.c3 where bf3.a3 = 5")
+      }
+      withSQLConf(
+        SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "15000") {
+        assertDidNotRewriteWithBloomFilter("select * from " +
+          "(select * from bf5filtered union all select * from bf2) t " +
+          "join bf3 on t.c5 = bf3.c3 where bf3.a3 = 5")
+      }
+    }
+  }
+
+  test("Runtime bloom filter join: simple expressions only") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      val squared = (s: Long) => {
+        s * s
+      }
+      spark.udf.register("square", squared)
+      assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on " +
+        "bf1.c1 = bf2.c2 where square(bf2.a2) = 62" )
+      assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on " +
+        "bf1.c1 = square(bf2.c2) where bf2.a2 = 62" )
+    }
+  }
+}
