You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@paimon.apache.org by lz...@apache.org on 2023/07/05 03:59:58 UTC

[incubator-paimon] branch master updated: [spark] disable shuffle while writing unaware-bucket table (#1454)

This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 0f2b772ca [spark] disable shuffle while writing unaware-bucket table (#1454)
0f2b772ca is described below

commit 0f2b772ca322ffe16d9707bde6caea14e1b39e3b
Author: YeJunHao <41...@users.noreply.github.com>
AuthorDate: Wed Jul 5 11:59:55 2023 +0800

    [spark] disable shuffle while writing unaware-bucket table (#1454)
---
 .../paimon/spark/commands/PaimonCommand.scala      |   8 +-
 .../spark/commands/WriteIntoPaimonTable.scala      | 109 ++++++++++++++-------
 .../org/apache/paimon/spark/SparkWriteITCase.java  |  23 +++++
 3 files changed, 99 insertions(+), 41 deletions(-)

diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
index ea951bb29..9e8c964b1 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
@@ -35,9 +35,11 @@ trait PaimonCommand {
 
   def getTable: Table
 
-  def isDynamicBucketTable: Boolean = {
-    getTable.isInstanceOf[FileStoreTable] &&
-    getTable.asInstanceOf[FileStoreTable].bucketMode == BucketMode.DYNAMIC
+  lazy val bucketMode: BucketMode = getTable match { // evaluated once, on first access
+    case fileStoreTable: FileStoreTable =>
+      fileStoreTable.bucketMode
+    case _ => // tables other than FileStoreTable are handled as fixed-bucket
+      BucketMode.FIXED
   }
 
   def deserializeCommitMessage(
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
index 29f56776a..0a482fdad 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
@@ -23,12 +23,12 @@ import org.apache.paimon.index.PartitionIndex
 import org.apache.paimon.spark.{DynamicOverWrite, InsertInto, Overwrite, SaveMode}
 import org.apache.paimon.spark.SparkRow
 import org.apache.paimon.spark.SparkUtils.createIOManager
-import org.apache.paimon.table.{FileStoreTable, Table}
+import org.apache.paimon.table.{BucketMode, FileStoreTable, Table}
 import org.apache.paimon.table.sink.{BatchWriteBuilder, CommitMessageSerializer, DynamicBucketRow, RowPartitionKeyExtractor}
 import org.apache.paimon.types.RowType
 
 import org.apache.spark.TaskContext
-import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.command.RunnableCommand
@@ -77,45 +77,58 @@ case class WriteIntoPaimonTable(_table: FileStoreTable, saveMode: SaveMode, data
     val toRow = withBucketDataEncoder.createSerializer()
     val fromRow = withBucketDataEncoder.createDeserializer()
 
-    val withAssignedBucket = if (isDynamicBucketTable) {
-      val partitioned = if (primaryKeyCols.nonEmpty) {
-        // Make sure that the records with the same bucket values is within a task.
-        withBucketCol.repartition(primaryKeyCols: _*)
-      } else {
-        withBucketCol
-      }
-      val numSparkPartitions = partitioned.rdd.getNumPartitions
-      val dynamicBucketProcessor =
-        DynamicBucketProcessor(table, rowType, bucketColIdx, numSparkPartitions, toRow, fromRow)
-      partitioned.mapPartitions(dynamicBucketProcessor.processPartition)(withBucketDataEncoder)
-    } else {
-      val commonBucketProcessor = CommonBucketProcessor(writeBuilder, bucketColIdx, toRow, fromRow)
-      withBucketCol.mapPartitions(commonBucketProcessor.processPartition)(withBucketDataEncoder)
+    // Shuffle so rows with equal partition-column and bucket values share one Spark partition.
+    def repartitionByBucket(ds: Dataset[Row]): DataFrame =
+      ds.toDF().repartition(partitionCols ++ Seq(col(BUCKET_COL)): _*)
 
-    val commitMessages =
-      withAssignedBucket
-        .toDF()
-        .repartition(partitionCols ++ Seq(col(BUCKET_COL)): _*)
-        .mapPartitions {
-          iter =>
-            val write = writeBuilder.newWrite()
-            write.withIOManager(createIOManager)
-            try {
-              iter.foreach {
-                row =>
-                  val bucket = row.getInt(bucketColIdx)
-                  val bucketColDropped = originFromRow(toRow(row))
-                  write.write(new DynamicBucketRow(new SparkRow(rowType, bucketColDropped), bucket))
-              }
-              val serializer = new CommitMessageSerializer
-              write.prepareCommit().asScala.map(serializer.serialize).toIterator
-            } finally {
-              write.close()
+    val df = // rows with the bucket column filled in, consumed by the write step below
+      bucketMode match {
+        case BucketMode.DYNAMIC => // buckets from DynamicBucketProcessor, then shuffled by bucket
+          val partitioned = if (primaryKeyCols.nonEmpty) {
+            // Make sure that records with the same primary key values are within one task.
+            withBucketCol.repartition(primaryKeyCols: _*)
+          } else {
+            withBucketCol
+          }
+          val numSparkPartitions = partitioned.rdd.getNumPartitions
+          val dynamicBucketProcessor =
+            DynamicBucketProcessor(table, rowType, bucketColIdx, numSparkPartitions, toRow, fromRow)
+          repartitionByBucket(
+            partitioned.mapPartitions(dynamicBucketProcessor.processPartition)(
+              withBucketDataEncoder))
+        case BucketMode.UNAWARE => // every row gets bucket 0; no repartition by bucket here
+          val unawareBucketProcessor = UnawareBucketProcessor(bucketColIdx, toRow, fromRow)
+          withBucketCol
+            .mapPartitions(unawareBucketProcessor.processPartition)(withBucketDataEncoder)
+            .toDF()
+        case BucketMode.FIXED => // buckets from CommonBucketProcessor, then shuffled by bucket
+          val commonBucketProcessor =
+            CommonBucketProcessor(writeBuilder, bucketColIdx, toRow, fromRow)
+          repartitionByBucket(
+            withBucketCol.mapPartitions(commonBucketProcessor.processPartition)(
+              withBucketDataEncoder))
+      }
+
+    val commitMessages = df // one table writer per Spark partition; results gathered on driver
+      .mapPartitions {
+        iter =>
+          val write = writeBuilder.newWrite()
+          write.withIOManager(createIOManager)
+          try {
+            iter.foreach {
+              row =>
+                val bucket = row.getInt(bucketColIdx)
+                val bucketColDropped = originFromRow(toRow(row)) // bucket column stripped
+                write.write(new DynamicBucketRow(new SparkRow(rowType, bucketColDropped), bucket))
             }
-        }
-        .collect()
-        .map(deserializeCommitMessage(serializer, _))
+            val serializer = new CommitMessageSerializer
+            write.prepareCommit().asScala.map(serializer.serialize).toIterator // serialized bytes
+          } finally {
+            write.close() // always release the writer, even if writing fails
+          }
+      }
+      .collect() // brings every task's serialized commit messages to the driver
+      .map(deserializeCommitMessage(serializer, _))
 
     try {
       val tableCommit = if (overwritePartition == null) {
@@ -230,4 +243,24 @@ object WriteIntoPaimonTable {
       }
     }
   }
+
+  /**
+   * Bucket processor for unaware-bucket tables: writes 0 into the bucket column of every row.
+   * Rows stay in their current Spark partition; no repartitioning happens here.
+   */
+  case class UnawareBucketProcessor(
+      bucketColIndex: Int,
+      toRow: ExpressionEncoder.Serializer[Row],
+      fromRow: ExpressionEncoder.Deserializer[Row])
+    extends BucketProcessor {
+
+    /** Lazily rewrites the bucket column (at `bucketColIndex`) of each row to 0. */
+    def processPartition(rowIterator: Iterator[Row]): Iterator[Row] =
+      rowIterator.map {
+        row =>
+          val sparkInternalRow = toRow(row)
+          sparkInternalRow.setInt(bucketColIndex, 0)
+          fromRow(sparkInternalRow)
+      }
+  }
 }
diff --git a/paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkWriteITCase.java b/paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkWriteITCase.java
index 401a56337..d5302fdee 100644
--- a/paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkWriteITCase.java
+++ b/paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkWriteITCase.java
@@ -157,6 +157,29 @@ public class SparkWriteITCase {
         assertThat(rows.toString()).isEqualTo("[[[1],2], [[2],0]]");
     }
 
+    @Test
+    public void testReadWriteUnawareBucketTable() {
+        spark.sql(
+                "CREATE TABLE T (a INT, b INT, c STRING) PARTITIONED BY (a) TBLPROPERTIES"
+                        + " ('write-mode'='append-only', 'bucket'='-1')");
+        // Unaware-bucket ('bucket'='-1') append-only table; each partition value is inserted twice.
+        spark.sql("INSERT INTO T VALUES (1, 1, '1'), (1, 2, '2')");
+        spark.sql("INSERT INTO T VALUES (1, 1, '1'), (1, 2, '2')");
+        spark.sql("INSERT INTO T VALUES (2, 1, '1'), (2, 2, '2')");
+        spark.sql("INSERT INTO T VALUES (2, 1, '1'), (2, 2, '2')");
+        spark.sql("INSERT INTO T VALUES (3, 1, '1'), (3, 2, '2')");
+        spark.sql("INSERT INTO T VALUES (3, 1, '1'), (3, 2, '2')");
+        // Six inserts of two rows each.
+        List<Row> rows = spark.sql("SELECT count(1) FROM T").collectAsList();
+        assertThat(rows.toString()).isEqualTo("[[12]]");
+        // Duplicate rows from separate inserts are both kept.
+        rows = spark.sql("SELECT * FROM T WHERE b = 2 AND a = 1").collectAsList();
+        assertThat(rows.toString()).isEqualTo("[[1,2,2], [1,2,2]]");
+        // Unaware-bucket writes are expected to place every data file in bucket 0.
+        rows = spark.sql("SELECT max(bucket) FROM `T$FILES`").collectAsList();
+        assertThat(rows.toString()).isEqualTo("[[0]]");
+    }
+
     @Test
     public void testNonnull() {
         try {