Posted to commits@paimon.apache.org by lz...@apache.org on 2023/07/05 03:59:58 UTC
[incubator-paimon] branch master updated: [spark] disable shuffle while writing unaware-bucket table (#1454)
This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 0f2b772ca [spark] disable shuffle while writing unaware-bucket table (#1454)
0f2b772ca is described below
commit 0f2b772ca322ffe16d9707bde6caea14e1b39e3b
Author: YeJunHao <41...@users.noreply.github.com>
AuthorDate: Wed Jul 5 11:59:55 2023 +0800
[spark] disable shuffle while writing unaware-bucket table (#1454)
---
.../paimon/spark/commands/PaimonCommand.scala | 8 +-
.../spark/commands/WriteIntoPaimonTable.scala | 109 ++++++++++++++-------
.../org/apache/paimon/spark/SparkWriteITCase.java | 23 +++++
3 files changed, 99 insertions(+), 41 deletions(-)
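For context: an unaware-bucket table is an append-only table declared with 'bucket' = '-1'. All of its data files belong to bucket 0, so repartitioning the DataFrame by (partition columns, bucket column) before writing only adds a shuffle without changing the data layout; that is the shuffle this commit drops. Below is a minimal local-Spark sketch (not part of this commit; column names are made up) showing the Exchange that the repartition introduces and that the unaware-bucket path now skips:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{col, lit}

    object ShuffleDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]")
          .appName("unaware-bucket-shuffle")
          .getOrCreate()
        import spark.implicits._

        // Toy data shaped like the test table T(a, b, c), plus a constant bucket column.
        val df = Seq((1, 1, "1"), (1, 2, "2"))
          .toDF("a", "b", "c")
          .withColumn("_bucket_", lit(0))

        // Fixed/dynamic-bucket path: co-locate rows of the same (partition, bucket).
        // The physical plan contains an "Exchange hashpartitioning(a, _bucket_, ...)" node.
        df.repartition(col("a"), col("_bucket_")).explain()

        // Unaware-bucket path after this change: the bucket is always 0, so the
        // DataFrame is written as-is and the plan contains no Exchange.
        df.explain()

        spark.stop()
      }
    }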
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
index ea951bb29..9e8c964b1 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
@@ -35,9 +35,11 @@ trait PaimonCommand {
def getTable: Table
- def isDynamicBucketTable: Boolean = {
- getTable.isInstanceOf[FileStoreTable] &&
- getTable.asInstanceOf[FileStoreTable].bucketMode == BucketMode.DYNAMIC
+ lazy val bucketMode: BucketMode = getTable match {
+ case fileStoreTable: FileStoreTable =>
+ fileStoreTable.bucketMode
+ case _ =>
+ BucketMode.FIXED
}
def deserializeCommitMessage(
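The refactor above replaces the boolean isDynamicBucketTable with a lazily computed bucketMode, so callers can distinguish FIXED, DYNAMIC and UNAWARE tables; anything that is not a FileStoreTable falls back to FIXED. A self-contained sketch of the same pattern, using hypothetical stand-in types rather than the real Paimon classes:

    object BucketModeSketch {
      // Stand-ins for org.apache.paimon.table.BucketMode and the table classes;
      // only the shape of the dispatch matters here.
      sealed trait BucketMode
      case object FIXED extends BucketMode
      case object DYNAMIC extends BucketMode
      case object UNAWARE extends BucketMode

      sealed trait Table
      final case class FileStoreTable(bucketMode: BucketMode) extends Table
      final case class SystemTable(name: String) extends Table

      trait PaimonCommand {
        def getTable: Table

        // Computed once on first use; non-FileStoreTable tables default to FIXED,
        // mirroring the diff above.
        lazy val bucketMode: BucketMode = getTable match {
          case t: FileStoreTable => t.bucketMode
          case _ => FIXED
        }
      }

      def main(args: Array[String]): Unit = {
        val cmd = new PaimonCommand { def getTable: Table = FileStoreTable(UNAWARE) }
        println(cmd.bucketMode) // UNAWARE
      }
    }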
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
index 29f56776a..0a482fdad 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
@@ -23,12 +23,12 @@ import org.apache.paimon.index.PartitionIndex
import org.apache.paimon.spark.{DynamicOverWrite, InsertInto, Overwrite, SaveMode}
import org.apache.paimon.spark.SparkRow
import org.apache.paimon.spark.SparkUtils.createIOManager
-import org.apache.paimon.table.{FileStoreTable, Table}
+import org.apache.paimon.table.{BucketMode, FileStoreTable, Table}
import org.apache.paimon.table.sink.{BatchWriteBuilder, CommitMessageSerializer, DynamicBucketRow, RowPartitionKeyExtractor}
import org.apache.paimon.types.RowType
import org.apache.spark.TaskContext
-import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.RunnableCommand
@@ -77,45 +77,58 @@ case class WriteIntoPaimonTable(_table: FileStoreTable, saveMode: SaveMode, data
val toRow = withBucketDataEncoder.createSerializer()
val fromRow = withBucketDataEncoder.createDeserializer()
- val withAssignedBucket = if (isDynamicBucketTable) {
- val partitioned = if (primaryKeyCols.nonEmpty) {
- // Make sure that the records with the same bucket values is within a task.
- withBucketCol.repartition(primaryKeyCols: _*)
- } else {
- withBucketCol
- }
- val numSparkPartitions = partitioned.rdd.getNumPartitions
- val dynamicBucketProcessor =
- DynamicBucketProcessor(table, rowType, bucketColIdx, numSparkPartitions, toRow, fromRow)
- partitioned.mapPartitions(dynamicBucketProcessor.processPartition)(withBucketDataEncoder)
- } else {
- val commonBucketProcessor = CommonBucketProcessor(writeBuilder, bucketColIdx, toRow, fromRow)
- withBucketCol.mapPartitions(commonBucketProcessor.processPartition)(withBucketDataEncoder)
+ def repartitionByBucket(ds: Dataset[Row]) = {
+ ds.toDF().repartition(partitionCols ++ Seq(col(BUCKET_COL)): _*)
}
- val commitMessages =
- withAssignedBucket
- .toDF()
- .repartition(partitionCols ++ Seq(col(BUCKET_COL)): _*)
- .mapPartitions {
- iter =>
- val write = writeBuilder.newWrite()
- write.withIOManager(createIOManager)
- try {
- iter.foreach {
- row =>
- val bucket = row.getInt(bucketColIdx)
- val bucketColDropped = originFromRow(toRow(row))
- write.write(new DynamicBucketRow(new SparkRow(rowType, bucketColDropped), bucket))
- }
- val serializer = new CommitMessageSerializer
- write.prepareCommit().asScala.map(serializer.serialize).toIterator
- } finally {
- write.close()
+ val df =
+ bucketMode match {
+ case BucketMode.DYNAMIC =>
+ val partitioned = if (primaryKeyCols.nonEmpty) {
+ // Make sure that records with the same bucket value are within a task.
+ withBucketCol.repartition(primaryKeyCols: _*)
+ } else {
+ withBucketCol
+ }
+ val numSparkPartitions = partitioned.rdd.getNumPartitions
+ val dynamicBucketProcessor =
+ DynamicBucketProcessor(table, rowType, bucketColIdx, numSparkPartitions, toRow, fromRow)
+ repartitionByBucket(
+ partitioned.mapPartitions(dynamicBucketProcessor.processPartition)(
+ withBucketDataEncoder))
+ case BucketMode.UNAWARE =>
+ val unawareBucketProcessor = UnawareBucketProcessor(bucketColIdx, toRow, fromRow)
+ withBucketCol
+ .mapPartitions(unawareBucketProcessor.processPartition)(withBucketDataEncoder)
+ .toDF()
+ case BucketMode.FIXED =>
+ val commonBucketProcessor =
+ CommonBucketProcessor(writeBuilder, bucketColIdx, toRow, fromRow)
+ repartitionByBucket(
+ withBucketCol.mapPartitions(commonBucketProcessor.processPartition)(
+ withBucketDataEncoder))
+ }
+
+ val commitMessages = df
+ .mapPartitions {
+ iter =>
+ val write = writeBuilder.newWrite()
+ write.withIOManager(createIOManager)
+ try {
+ iter.foreach {
+ row =>
+ val bucket = row.getInt(bucketColIdx)
+ val bucketColDropped = originFromRow(toRow(row))
+ write.write(new DynamicBucketRow(new SparkRow(rowType, bucketColDropped), bucket))
}
- }
- .collect()
- .map(deserializeCommitMessage(serializer, _))
+ val serializer = new CommitMessageSerializer
+ write.prepareCommit().asScala.map(serializer.serialize).toIterator
+ } finally {
+ write.close()
+ }
+ }
+ .collect()
+ .map(deserializeCommitMessage(serializer, _))
try {
val tableCommit = if (overwritePartition == null) {
@@ -230,4 +243,24 @@ object WriteIntoPaimonTable {
}
}
}
+
+ case class UnawareBucketProcessor(
+ bucketColIndex: Int,
+ toRow: ExpressionEncoder.Serializer[Row],
+ fromRow: ExpressionEncoder.Deserializer[Row])
+ extends BucketProcessor {
+
+ def processPartition(rowIterator: Iterator[Row]): Iterator[Row] = {
+ new Iterator[Row] {
+ override def hasNext: Boolean = rowIterator.hasNext
+
+ override def next(): Row = {
+ val row = rowIterator.next
+ val sparkInternalRow = toRow(row)
+ sparkInternalRow.setInt(bucketColIndex, 0)
+ fromRow(sparkInternalRow)
+ }
+ }
+ }
+ }
}
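The new UnawareBucketProcessor does no real bucketing: it walks the partition's rows and pins the bucket column to 0, because an unaware-bucket table has exactly one (virtual) bucket. A plain-Scala sketch of the same idea, without Spark's ExpressionEncoder machinery (the row type here is hypothetical):

    object UnawareBucketSketch {
      // Hypothetical row shape: the payload plus the appended bucket column.
      final case class RowWithBucket(payload: Seq[Any], bucket: Int)

      // Analogue of UnawareBucketProcessor.processPartition: force every row into
      // bucket 0 without moving data between partitions.
      def processPartition(rows: Iterator[RowWithBucket]): Iterator[RowWithBucket] =
        rows.map(_.copy(bucket = 0))

      def main(args: Array[String]): Unit = {
        val out = processPartition(Iterator(RowWithBucket(Seq(1, "x"), 7))).toList
        println(out) // List(RowWithBucket(List(1, x),0))
      }
    }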
diff --git a/paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkWriteITCase.java b/paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkWriteITCase.java
index 401a56337..d5302fdee 100644
--- a/paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkWriteITCase.java
+++ b/paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkWriteITCase.java
@@ -157,6 +157,29 @@ public class SparkWriteITCase {
assertThat(rows.toString()).isEqualTo("[[[1],2], [[2],0]]");
}
+ @Test
+ public void testReadWriteUnawareBucketTable() {
+ spark.sql(
+ "CREATE TABLE T (a INT, b INT, c STRING) PARTITIONED BY (a) TBLPROPERTIES"
+ + " ('write-mode'='append-only', 'bucket'='-1')");
+
+ spark.sql("INSERT INTO T VALUES (1, 1, '1'), (1, 2, '2')");
+ spark.sql("INSERT INTO T VALUES (1, 1, '1'), (1, 2, '2')");
+ spark.sql("INSERT INTO T VALUES (2, 1, '1'), (2, 2, '2')");
+ spark.sql("INSERT INTO T VALUES (2, 1, '1'), (2, 2, '2')");
+ spark.sql("INSERT INTO T VALUES (3, 1, '1'), (3, 2, '2')");
+ spark.sql("INSERT INTO T VALUES (3, 1, '1'), (3, 2, '2')");
+
+ List<Row> rows = spark.sql("SELECT count(1) FROM T").collectAsList();
+ assertThat(rows.toString()).isEqualTo("[[12]]");
+
+ rows = spark.sql("SELECT * FROM T WHERE b = 2 AND a = 1").collectAsList();
+ assertThat(rows.toString()).isEqualTo("[[1,2,2], [1,2,2]]");
+
+ rows = spark.sql("SELECT max(bucket) FROM `T$FILES`").collectAsList();
+ assertThat(rows.toString()).isEqualTo("[[0]]");
+ }
+
@Test
public void testNonnull() {
try {
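The added test exercises the whole path end to end: a partitioned, append-only table with 'bucket' = '-1' is written six times, read back, and its files are checked to all sit in bucket 0. Outside the ITCase harness, roughly the same scenario can be reproduced from a Spark shell; the catalog name, catalog registration and warehouse path below are assumptions about the local setup, not part of this commit:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[2]")
      .config("spark.sql.catalog.paimon", "org.apache.paimon.spark.SparkCatalog")
      .config("spark.sql.catalog.paimon.warehouse", "file:///tmp/paimon-warehouse")
      .getOrCreate()

    spark.sql("USE paimon.default")

    // Same DDL as the test: append-only with 'bucket' = '-1' marks the table unaware-bucket.
    spark.sql(
      """CREATE TABLE T (a INT, b INT, c STRING) PARTITIONED BY (a)
        |TBLPROPERTIES ('write-mode'='append-only', 'bucket'='-1')""".stripMargin)

    spark.sql("INSERT INTO T VALUES (1, 1, '1'), (1, 2, '2')")

    // Every file of an unaware-bucket table lands in bucket 0.
    spark.sql("SELECT max(bucket) FROM `T$FILES`").show()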