Posted to commits@hudi.apache.org by ga...@apache.org on 2021/05/29 14:50:39 UTC

[hudi] branch master updated: [HUDI-1879] Support Partition Prune For MergeOnRead Snapshot Table (#2926)

This is an automated email from the ASF dual-hosted git repository.

garyli pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git


The following commit(s) were added to refs/heads/master by this push:
     new dcd7c33  [HUDI-1879] Support Partition Prune For MergeOnRead Snapshot Table (#2926)
dcd7c33 is described below

commit dcd7c331dc72df9ab10e4867a3592faf89f1480b
Author: pengzhiwei <pe...@icloud.com>
AuthorDate: Sat May 29 22:50:24 2021 +0800

    [HUDI-1879] Support Partition Prune For MergeOnRead Snapshot Table (#2926)
---
 .../scala/org/apache/hudi/HoodieFileIndex.scala    |   9 +-
 .../scala/org/apache/hudi/HoodieSparkUtils.scala   |  96 ++++++++++++
 .../apache/hudi/MergeOnReadSnapshotRelation.scala  |  20 ++-
 .../TestConvertFilterToCatalystExpression.scala    | 165 +++++++++++++++++++++
 .../apache/hudi/functional/TestMORDataSource.scala |  65 +++++++-
 5 files changed, 350 insertions(+), 5 deletions(-)

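For context, the read pattern this change speeds up looks roughly as follows (a minimal spark-shell sketch; the table path is hypothetical and "partition" is the partition column name used in the tests below):

    // Sketch only: the base path is made up and "spark" is an existing SparkSession.
    val df = spark.read.format("hudi")
      .load("/tmp/hudi_mor_table")
      .filter("partition = '2021/03/01'")
    // With this change, the data source filter on "partition" is converted to a
    // Catalyst expression and passed to HoodieFileIndex.listFiles, so only file
    // slices under the matching partition are listed and scanned.
    df.count()
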
diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieFileIndex.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieFileIndex.scala
index 61c2f3a..8e7e1f8 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieFileIndex.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieFileIndex.scala
@@ -262,7 +262,14 @@ case class HoodieFileIndex(
           // If the number of partition columns does not match the number of partition
           // fragments and there is exactly one partition column, map the whole partition
           // path to that column so the query can still benefit from partition pruning.
-          InternalRow.fromSeq(Seq(UTF8String.fromString(partitionPath)))
+          val prefix = s"${partitionSchema.fieldNames.head}="
+          val partitionValue = if (partitionPath.startsWith(prefix)) {
+            // Support Hive-style partition paths of the form "<column>=<value>"
+            partitionPath.substring(prefix.length)
+          } else {
+            partitionPath
+          }
+          InternalRow.fromSeq(Seq(UTF8String.fromString(partitionValue)))
         } else if (partitionFragments.length != partitionSchema.fields.length &&
           partitionSchema.fields.length > 1) {
           // If the partition column size is not equal to the partition fragments size
diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala
index 72b26be..ee83cf4 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala
@@ -28,8 +28,10 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 import org.apache.spark.sql.avro.SchemaConverters
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal}
 import org.apache.spark.sql.execution.datasources.{FileStatusCache, InMemoryFileIndex, Spark2ParsePartitionUtil, Spark3ParsePartitionUtil, SparkParsePartitionUtil}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.{And, EqualNullSafe, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not, Or, StringContains, StringEndsWith, StringStartsWith}
 import org.apache.spark.sql.types.{StringType, StructField, StructType}
 
 import scala.collection.JavaConverters._
@@ -128,4 +130,98 @@ object HoodieSparkUtils {
       new Spark3ParsePartitionUtil(conf)
     }
   }
+
+  /**
+   * Convert data source Filters to Catalyst Expressions and join them with And. If every
+   * filter converts successfully, return a non-empty Option[Expression]; otherwise return None.
+   */
+  def convertToCatalystExpressions(filters: Array[Filter],
+                                   tableSchema: StructType): Option[Expression] = {
+    val expressions = filters.map(convertToCatalystExpression(_, tableSchema))
+    if (expressions.forall(p => p.isDefined)) {
+      if (expressions.isEmpty) {
+        None
+      } else if (expressions.length == 1) {
+        expressions(0)
+      } else {
+        Some(expressions.map(_.get).reduce(org.apache.spark.sql.catalyst.expressions.And))
+      }
+    } else {
+      None
+    }
+  }
+
+  /**
+   * Convert a single data source Filter to a Catalyst Expression. If the conversion
+   * succeeds, return a non-empty Option[Expression]; otherwise return None.
+   */
+  def convertToCatalystExpression(filter: Filter, tableSchema: StructType): Option[Expression] = {
+    Option(
+      filter match {
+        case EqualTo(attribute, value) =>
+          org.apache.spark.sql.catalyst.expressions.EqualTo(toAttribute(attribute, tableSchema), Literal.create(value))
+        case EqualNullSafe(attribute, value) =>
+          org.apache.spark.sql.catalyst.expressions.EqualNullSafe(toAttribute(attribute, tableSchema), Literal.create(value))
+        case GreaterThan(attribute, value) =>
+          org.apache.spark.sql.catalyst.expressions.GreaterThan(toAttribute(attribute, tableSchema), Literal.create(value))
+        case GreaterThanOrEqual(attribute, value) =>
+          org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual(toAttribute(attribute, tableSchema), Literal.create(value))
+        case LessThan(attribute, value) =>
+          org.apache.spark.sql.catalyst.expressions.LessThan(toAttribute(attribute, tableSchema), Literal.create(value))
+        case LessThanOrEqual(attribute, value) =>
+          org.apache.spark.sql.catalyst.expressions.LessThanOrEqual(toAttribute(attribute, tableSchema), Literal.create(value))
+        case In(attribute, values) =>
+          val attrExp = toAttribute(attribute, tableSchema)
+          val valuesExp = values.map(v => Literal.create(v))
+          org.apache.spark.sql.catalyst.expressions.In(attrExp, valuesExp)
+        case IsNull(attribute) =>
+          org.apache.spark.sql.catalyst.expressions.IsNull(toAttribute(attribute, tableSchema))
+        case IsNotNull(attribute) =>
+          org.apache.spark.sql.catalyst.expressions.IsNotNull(toAttribute(attribute, tableSchema))
+        case And(left, right) =>
+          val leftExp = convertToCatalystExpression(left, tableSchema)
+          val rightExp = convertToCatalystExpression(right, tableSchema)
+          if (leftExp.isEmpty || rightExp.isEmpty) {
+            null
+          } else {
+            org.apache.spark.sql.catalyst.expressions.And(leftExp.get, rightExp.get)
+          }
+        case Or(left, right) =>
+          val leftExp = convertToCatalystExpression(left, tableSchema)
+          val rightExp = convertToCatalystExpression(right, tableSchema)
+          if (leftExp.isEmpty || rightExp.isEmpty) {
+            null
+          } else {
+            org.apache.spark.sql.catalyst.expressions.Or(leftExp.get, rightExp.get)
+          }
+        case Not(child) =>
+          val childExp = convertToCatalystExpression(child, tableSchema)
+          if (childExp.isEmpty) {
+            null
+          } else {
+            org.apache.spark.sql.catalyst.expressions.Not(childExp.get)
+          }
+        case StringStartsWith(attribute, value) =>
+          val leftExp = toAttribute(attribute, tableSchema)
+          val rightExp = Literal.create(s"$value%")
+          org.apache.spark.sql.catalyst.expressions.Like(leftExp, rightExp)
+        case StringEndsWith(attribute, value) =>
+          val leftExp = toAttribute(attribute, tableSchema)
+          val rightExp = Literal.create(s"%$value")
+          org.apache.spark.sql.catalyst.expressions.Like(leftExp, rightExp)
+        case StringContains(attribute, value) =>
+          val leftExp = toAttribute(attribute, tableSchema)
+          val rightExp = Literal.create(s"%$value%")
+          org.apache.spark.sql.catalyst.expressions.Like(leftExp, rightExp)
+        case _ => null
+      }
+    )
+  }
+
+  private def toAttribute(columnName: String, tableSchema: StructType): AttributeReference = {
+    val field = tableSchema.find(p => p.name == columnName)
+    assert(field.isDefined, s"Cannot find column: $columnName. Table columns: " +
+      s"${tableSchema.fieldNames.mkString(",")}")
+    AttributeReference(columnName, field.get.dataType, field.get.nullable)()
+  }
 }
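
A usage sketch for the new conversion helpers (the schema and filters below are illustrative; the expected SQL strings follow the same pattern as the unit test added further down):

    import org.apache.hudi.HoodieSparkUtils
    import org.apache.spark.sql.sources.{EqualTo, Filter, GreaterThan}
    import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

    val schema = StructType(Seq(
      StructField("id", LongType, nullable = false),
      StructField("name", StringType, nullable = true)))

    // Every filter converts, so the result is defined and the parts are joined with AND.
    val filters: Array[Filter] = Array(EqualTo("id", 1), GreaterThan("name", "a"))
    val expr = HoodieSparkUtils.convertToCatalystExpressions(filters, schema)
    expr.map(_.sql) // roughly Some("((`id` = 1) AND (`name` > 'a'))")

    // A filter type with no Catalyst mapping makes the whole conversion return None,
    // in which case the caller falls back to listing all files.
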
diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/MergeOnReadSnapshotRelation.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/MergeOnReadSnapshotRelation.scala
index c9d413b..13cf43e 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/MergeOnReadSnapshotRelation.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/MergeOnReadSnapshotRelation.scala
@@ -67,7 +67,6 @@ class MergeOnReadSnapshotRelation(val sqlContext: SQLContext,
     DataSourceReadOptions.REALTIME_MERGE_OPT_KEY,
     DataSourceReadOptions.DEFAULT_REALTIME_MERGE_OPT_VAL)
   private val maxCompactionMemoryInBytes = getMaxCompactionMemoryInBytes(jobConf)
-  private val fileIndex = buildFileIndex()
   private val preCombineField = {
     val preCombineFieldFromTableConfig = metaClient.getTableConfig.getPreCombineField
     if (preCombineFieldFromTableConfig != null) {
@@ -94,6 +93,8 @@ class MergeOnReadSnapshotRelation(val sqlContext: SQLContext,
     })
     val requiredAvroSchema = AvroConversionUtils
       .convertStructTypeToAvroSchema(requiredStructSchema, tableAvroSchema.getName, tableAvroSchema.getNamespace)
+
+    val fileIndex = buildFileIndex(filters)
     val hoodieTableState = HoodieMergeOnReadTableState(
       tableStructSchema,
       requiredStructSchema,
@@ -131,7 +132,8 @@ class MergeOnReadSnapshotRelation(val sqlContext: SQLContext,
     rdd.asInstanceOf[RDD[Row]]
   }
 
-  def buildFileIndex(): List[HoodieMergeOnReadFileSplit] = {
+  def buildFileIndex(filters: Array[Filter]): List[HoodieMergeOnReadFileSplit] = {
+
     val fileStatuses = if (globPaths.isDefined) {
       // Load files from the glob paths, if defined, to stay compatible with the original mode
       val inMemoryFileIndex = HoodieSparkUtils.createInMemoryFileIndex(sqlContext.sparkSession, globPaths.get)
@@ -139,7 +141,19 @@ class MergeOnReadSnapshotRelation(val sqlContext: SQLContext,
     } else { // Load files by the HoodieFileIndex.
       val hoodieFileIndex = HoodieFileIndex(sqlContext.sparkSession, metaClient,
         Some(tableStructSchema), optParams, FileStatusCache.getOrCreate(sqlContext.sparkSession))
-      hoodieFileIndex.allFiles
+
+      // Keep only the filters that reference partition columns and convert them to a Catalyst expression
+      val partitionColumns = hoodieFileIndex.partitionSchema.fieldNames.toSet
+      val partitionFilters = filters.filter(f => f.references.forall(p => partitionColumns.contains(p)))
+      val partitionFilterExpression =
+        HoodieSparkUtils.convertToCatalystExpressions(partitionFilters, tableStructSchema)
+
+      // If the conversion succeeds, use the expression to prune partitions while listing files
+      if (partitionFilterExpression.isDefined) {
+        hoodieFileIndex.listFiles(Seq(partitionFilterExpression.get), Seq.empty).flatMap(_.files)
+      } else {
+        hoodieFileIndex.allFiles
+      }
     }
 
     if (fileStatuses.isEmpty) { // If this is an empty table, return an empty split list.
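
The filter split in buildFileIndex keeps only predicates whose referenced columns are all partition columns; everything else stays with the regular scan. A small illustration (the column names are hypothetical):

    import org.apache.spark.sql.sources.{EqualTo, Filter, GreaterThan}

    val partitionColumns = Set("partition")
    val filters: Array[Filter] = Array(
      EqualTo("partition", "2021/03/01"), // references only the partition column
      GreaterThan("price", 10.0))         // references a data column
    val partitionFilters = filters.filter(f => f.references.forall(partitionColumns.contains))
    // partitionFilters contains only EqualTo("partition", "2021/03/01"); that part is
    // converted to a Catalyst expression and pushed into hoodieFileIndex.listFiles.
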
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestConvertFilterToCatalystExpression.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestConvertFilterToCatalystExpression.scala
new file mode 100644
index 0000000..d1a1170
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestConvertFilterToCatalystExpression.scala
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hudi
+
+import org.apache.hudi.HoodieSparkUtils.convertToCatalystExpressions
+import org.apache.hudi.HoodieSparkUtils.convertToCatalystExpression
+import org.apache.spark.sql.sources.{And, EqualNullSafe, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not, Or, StringContains, StringEndsWith, StringStartsWith}
+import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructField, StructType}
+import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api.Test
+
+import scala.collection.mutable.ArrayBuffer
+
+class TestConvertFilterToCatalystExpression {
+
+  private lazy val tableSchema = {
+    val fields = new ArrayBuffer[StructField]()
+    fields.append(StructField("id", LongType, nullable = false))
+    fields.append(StructField("name", StringType, nullable = true))
+    fields.append(StructField("price", DoubleType, nullable = true))
+    fields.append(StructField("ts", IntegerType, nullable = false))
+    StructType(fields)
+  }
+
+  @Test
+  def testBaseConvert(): Unit = {
+    checkConvertFilter(eq("id", 1), "(`id` = 1)")
+    checkConvertFilter(eqs("name", "a1"), "(`name` <=> 'a1')")
+    checkConvertFilter(lt("price", 10), "(`price` < 10)")
+    checkConvertFilter(lte("ts", 1), "(`ts` <= 1)")
+    checkConvertFilter(gt("price", 10), "(`price` > 10)")
+    checkConvertFilter(gte("price", 10), "(`price` >= 10)")
+    checkConvertFilter(in("id", 1, 2, 3), "(`id` IN (1, 2, 3))")
+    checkConvertFilter(isNull("id"), "(`id` IS NULL)")
+    checkConvertFilter(isNotNull("name"), "(`name` IS NOT NULL)")
+    checkConvertFilter(and(lt("ts", 10), gt("ts", 1)),
+      "((`ts` < 10) AND (`ts` > 1))")
+    checkConvertFilter(or(lte("ts", 10), gte("ts", 1)),
+      "((`ts` <= 10) OR (`ts` >= 1))")
+    checkConvertFilter(not(and(lt("ts", 10), gt("ts", 1))),
+      "(NOT ((`ts` < 10) AND (`ts` > 1)))")
+    checkConvertFilter(startWith("name", "ab"), "`name` LIKE 'ab%'")
+    checkConvertFilter(endWith("name", "cd"), "`name` LIKE '%cd'")
+    checkConvertFilter(contains("name", "e"), "`name` LIKE '%e%'")
+  }
+
+  @Test
+  def testConvertFilters(): Unit = {
+    checkConvertFilters(Array.empty[Filter], null)
+    checkConvertFilters(Array(eq("id", 1)), "(`id` = 1)")
+    checkConvertFilters(Array(lt("ts", 10), gt("ts", 1)),
+      "((`ts` < 10) AND (`ts` > 1))")
+  }
+
+  @Test
+  def testUnSupportConvert(): Unit = {
+    checkConvertFilters(Array(unsupport()), null)
+    checkConvertFilters(Array(and(unsupport(), eq("id", 1))), null)
+    checkConvertFilters(Array(or(unsupport(), eq("id", 1))), null)
+    checkConvertFilters(Array(and(eq("id", 1), not(unsupport()))), null)
+  }
+
+  private def checkConvertFilter(filter: Filter, expectExpression: String): Unit = {
+    val exp = convertToCatalystExpression(filter, tableSchema)
+    if (expectExpression == null) {
+      assertEquals(exp.isEmpty, true)
+    } else {
+      assertEquals(exp.isDefined, true)
+      assertEquals(expectExpression, exp.get.sql)
+    }
+  }
+
+  private def checkConvertFilters(filters: Array[Filter], expectExpression: String): Unit = {
+    val exp = convertToCatalystExpressions(filters, tableSchema)
+    if (expectExpression == null) {
+      assertEquals(exp.isEmpty, true)
+    } else {
+      assertEquals(exp.isDefined, true)
+      assertEquals(expectExpression, exp.get.sql)
+    }
+  }
+
+  private def eq(attribute: String, value: Any): Filter = {
+    EqualTo(attribute, value)
+  }
+
+  private def eqs(attribute: String, value: Any): Filter = {
+    EqualNullSafe(attribute, value)
+  }
+
+  private def gt(attribute: String, value: Any): Filter = {
+    GreaterThan(attribute, value)
+  }
+
+  private def gte(attribute: String, value: Any): Filter = {
+    GreaterThanOrEqual(attribute, value)
+  }
+
+  private def lt(attribute: String, value: Any): Filter = {
+    LessThan(attribute, value)
+  }
+
+  private def lte(attribute: String, value: Any): Filter = {
+    LessThanOrEqual(attribute, value)
+  }
+
+  private def in(attribute: String, values: Any*): Filter = {
+    In(attribute, values.toArray)
+  }
+
+  private def isNull(attribute: String): Filter = {
+    IsNull(attribute)
+  }
+
+  private def isNotNull(attribute: String): Filter = {
+    IsNotNull(attribute)
+  }
+
+  private def and(left: Filter, right: Filter): Filter = {
+    And(left, right)
+  }
+
+  private def or(left: Filter, right: Filter): Filter = {
+    Or(left, right)
+  }
+
+  private def not(child: Filter): Filter = {
+    Not(child)
+  }
+
+  private def startWith(attribute: String, value: String): Filter = {
+    StringStartsWith(attribute, value)
+  }
+
+  private def endWith(attribute: String, value: String): Filter = {
+    StringEndsWith(attribute, value)
+  }
+
+  private def contains(attribute: String, value: String): Filter = {
+    StringContains(attribute, value)
+  }
+
+  private def unsupport(): Filter = {
+    UnSupportFilter("")
+  }
+
+  case class UnSupportFilter(value: Any) extends Filter {
+    override def references: Array[String] = Array.empty
+  }
+}
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
index 00c40ab..eba2ac2 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.functions._
 import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue}
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 import org.junit.jupiter.params.ParameterizedTest
-import org.junit.jupiter.params.provider.ValueSource
+import org.junit.jupiter.params.provider.{CsvSource, ValueSource}
 
 import scala.collection.JavaConversions._
 
@@ -614,4 +614,67 @@ class TestMORDataSource extends HoodieClientTestBase {
       .load(basePath)
     assertEquals(N + 1, hoodieIncViewDF1.count())
   }
+
+  @ParameterizedTest
+  @CsvSource(Array("true, false", "false, true", "false, false", "true, true"))
+  def testMORPartitionPrune(partitionEncode: Boolean, hiveStylePartition: Boolean): Unit = {
+    val partitions = Array("2021/03/01", "2021/03/02", "2021/03/03", "2021/03/04", "2021/03/05")
+    val newDataGen = new HoodieTestDataGenerator(partitions)
+    val records1 = newDataGen.generateInsertsContainsAllPartitions("000", 100)
+    val inputDF1 = spark.read.json(spark.sparkContext.parallelize(recordsToStrings(records1), 2))
+
+    val partitionCounts = partitions.map(p => p -> records1.count(r => r.getPartitionPath == p)).toMap
+
+    inputDF1.write.format("hudi")
+      .options(commonOpts)
+      .option(DataSourceWriteOptions.OPERATION_OPT_KEY, DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL)
+      .option(DataSourceWriteOptions.TABLE_TYPE_OPT_KEY, DataSourceWriteOptions.MOR_TABLE_TYPE_OPT_VAL)
+      .option(DataSourceWriteOptions.URL_ENCODE_PARTITIONING_OPT_KEY, partitionEncode)
+      .option(DataSourceWriteOptions.HIVE_STYLE_PARTITIONING_OPT_KEY, hiveStylePartition)
+      .mode(SaveMode.Overwrite)
+      .save(basePath)
+
+    val count1 = spark.read.format("hudi")
+      .load(basePath)
+      .filter("partition = '2021/03/01'")
+      .count()
+    assertEquals(partitionCounts("2021/03/01"), count1)
+
+    val count2 = spark.read.format("hudi")
+      .load(basePath)
+      .filter("partition > '2021/03/01' and partition < '2021/03/03'")
+      .count()
+    assertEquals(partitionCounts("2021/03/02"), count2)
+
+    val count3 = spark.read.format("hudi")
+      .load(basePath)
+      .filter("partition != '2021/03/01'")
+      .count()
+    assertEquals(records1.size() - partitionCounts("2021/03/01"), count3)
+
+    val count4 = spark.read.format("hudi")
+      .load(basePath)
+      .filter("partition like '2021/03/03%'")
+      .count()
+    assertEquals(partitionCounts("2021/03/03"), count4)
+
+    val count5 = spark.read.format("hudi")
+      .load(basePath)
+      .filter("partition like '%2021/03/%'")
+      .count()
+    assertEquals(records1.size(), count5)
+
+    val count6 = spark.read.format("hudi")
+      .load(basePath)
+      .filter("partition = '2021/03/01' or partition = '2021/03/05'")
+      .count()
+    assertEquals(partitionCounts("2021/03/01") + partitionCounts("2021/03/05"), count6)
+
+    val count7 = spark.read.format("hudi")
+      .load(basePath)
+      .filter("substr(partition, 9, 10) = '03'")
+      .count()
+
+    assertEquals(partitionCounts("2021/03/03"), count7)
+  }
 }