You are viewing a plain-text version of this content. The hyperlink to its canonical version is not preserved in this plain-text copy.
Posted to commits@iceberg.apache.org by ru...@apache.org on 2021/10/18 22:03:25 UTC

[iceberg] branch master updated: Spark: Spark3 Sort Compaction Implementation (#2829)

This is an automated email from the ASF dual-hosted git repository.

russellspitzer pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new 61fe9b1  Spark: Spark3 Sort Compaction Implementation (#2829)
61fe9b1 is described below

commit 61fe9b1ce3ecff4354bb7624c9b5e3f76df2221b
Author: Russell Spitzer <rs...@apple.com>
AuthorDate: Mon Oct 18 17:03:14 2021 -0500

    Spark: Spark3 Sort Compaction Implementation (#2829)
    
    * Spark: Adds Spark3 Sort Based Compaction
    
    Implements sort-based compaction for Spark 3. It uses logic similar to the
    Spark3BinPackStrategy, but instead of doing a direct read then write, it
    issues a read, a sort, and then a write.
---
 .../apache/iceberg/actions/RewriteDataFiles.java   |  18 ++
 .../apache/iceberg/actions/BinPackStrategy.java    |   4 +
 .../org/apache/iceberg/actions/SortStrategy.java   |  13 +
 .../org/apache/iceberg/util/SortOrderUtil.java     |   2 +-
 .../actions/BaseRewriteDataFilesSparkAction.java   |  61 +++-
 .../actions/TestNewRewriteDataFilesAction.java     | 313 ++++++++++++++++++---
 spark/v3.0/build.gradle                            |   7 +
 .../actions/BaseRewriteDataFilesSpark3Action.java  |   6 +
 .../iceberg/spark/actions/Spark3SortStrategy.java  | 158 +++++++++++
 .../utils/DistributionAndOrderingUtils.scala       |   5 +-
 10 files changed, 534 insertions(+), 53 deletions(-)

diff --git a/api/src/main/java/org/apache/iceberg/actions/RewriteDataFiles.java b/api/src/main/java/org/apache/iceberg/actions/RewriteDataFiles.java
index eab8846..ce587d1 100644
--- a/api/src/main/java/org/apache/iceberg/actions/RewriteDataFiles.java
+++ b/api/src/main/java/org/apache/iceberg/actions/RewriteDataFiles.java
@@ -20,6 +20,7 @@
 package org.apache.iceberg.actions;
 
 import java.util.List;
+import org.apache.iceberg.SortOrder;
 import org.apache.iceberg.StructLike;
 import org.apache.iceberg.expressions.Expression;
 
@@ -85,6 +86,23 @@ public interface RewriteDataFiles extends SnapshotUpdate<RewriteDataFiles, Rewri
   }
 
   /**
+   * Choose SORT as a strategy for this rewrite operation using the table's sortOrder
+   * @return this for method chaining
+   */
+  default RewriteDataFiles sort() {
+    throw new UnsupportedOperationException("SORT Rewrite Strategy not implemented for this framework");
+  }
+
+  /**
+   * Choose SORT as a strategy for this rewrite operation and manually specify the sortOrder to use
+   * @param sortOrder user defined sortOrder
+   * @return this for method chaining
+   */
+  default RewriteDataFiles sort(SortOrder sortOrder) {
+    throw new UnsupportedOperationException("SORT Rewrite Strategy not implemented for this framework");
+  }
+
+  /**
    * A user provided filter for determining which files will be considered by the rewrite strategy. This will be used
    * in addition to whatever rules the rewrite strategy generates. For example this would be used for providing a
    * restriction to only run rewrite on a specific partition.
diff --git a/core/src/main/java/org/apache/iceberg/actions/BinPackStrategy.java b/core/src/main/java/org/apache/iceberg/actions/BinPackStrategy.java
index c3d1f7a..15dbb07 100644
--- a/core/src/main/java/org/apache/iceberg/actions/BinPackStrategy.java
+++ b/core/src/main/java/org/apache/iceberg/actions/BinPackStrategy.java
@@ -191,6 +191,10 @@ public abstract class BinPackStrategy implements RewriteStrategy {
     return fileToRewrite.stream().mapToLong(FileScanTask::length).sum();
   }
 
+  protected long maxGroupSize() {
+    return maxGroupSize;
+  }
+
   /**
    * Estimates a larger max target file size than our target size used in task creation to avoid
    * tasks which are predicted to have a certain size, but exceed that target size when serde is complete creating
diff --git a/core/src/main/java/org/apache/iceberg/actions/SortStrategy.java b/core/src/main/java/org/apache/iceberg/actions/SortStrategy.java
index 45f428e..bda632c 100644
--- a/core/src/main/java/org/apache/iceberg/actions/SortStrategy.java
+++ b/core/src/main/java/org/apache/iceberg/actions/SortStrategy.java
@@ -19,12 +19,15 @@
 
 package org.apache.iceberg.actions;
 
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import org.apache.iceberg.FileScanTask;
 import org.apache.iceberg.SortOrder;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet;
+import org.apache.iceberg.util.BinPacking;
+import org.apache.iceberg.util.BinPacking.ListPacker;
 import org.apache.iceberg.util.PropertyUtil;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -113,6 +116,16 @@ public abstract class SortStrategy extends BinPackStrategy {
     }
   }
 
+  @Override
+  public Iterable<List<FileScanTask>> planFileGroups(Iterable<FileScanTask> dataFiles) {
+    if (rewriteAll) {
+      ListPacker<FileScanTask> packer = new BinPacking.ListPacker<>(maxGroupSize(), 1, false);
+      return packer.pack(dataFiles, FileScanTask::length);
+    } else {
+      return super.planFileGroups(dataFiles);
+    }
+  }
+
   protected void validateOptions() {
     Preconditions.checkArgument(!sortOrder.isUnsorted(),
         "Can't use %s when there is no sort order, either define table %s's sort order or set sort" +
diff --git a/core/src/main/java/org/apache/iceberg/util/SortOrderUtil.java b/core/src/main/java/org/apache/iceberg/util/SortOrderUtil.java
index 54941e1..809d348 100644
--- a/core/src/main/java/org/apache/iceberg/util/SortOrderUtil.java
+++ b/core/src/main/java/org/apache/iceberg/util/SortOrderUtil.java
@@ -44,7 +44,7 @@ public class SortOrderUtil {
     return buildSortOrder(table.schema(), table.spec(), table.sortOrder());
   }
 
-  static SortOrder buildSortOrder(Schema schema, PartitionSpec spec, SortOrder sortOrder) {
+  public static SortOrder buildSortOrder(Schema schema, PartitionSpec spec, SortOrder sortOrder) {
     if (sortOrder.isUnsorted() && spec.isUnpartitioned()) {
       return SortOrder.unsorted();
     }
diff --git a/spark/src/main/java/org/apache/iceberg/spark/actions/BaseRewriteDataFilesSparkAction.java b/spark/src/main/java/org/apache/iceberg/spark/actions/BaseRewriteDataFilesSparkAction.java
index 7a6a7a9..406b2bb 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/actions/BaseRewriteDataFilesSparkAction.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/actions/BaseRewriteDataFilesSparkAction.java
@@ -34,6 +34,7 @@ import java.util.stream.Collectors;
 import java.util.stream.Stream;
 import org.apache.iceberg.DataFile;
 import org.apache.iceberg.FileScanTask;
+import org.apache.iceberg.SortOrder;
 import org.apache.iceberg.StructLike;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.actions.BaseRewriteDataFilesFileGroupInfo;
@@ -43,6 +44,8 @@ import org.apache.iceberg.actions.RewriteDataFiles;
 import org.apache.iceberg.actions.RewriteDataFilesCommitManager;
 import org.apache.iceberg.actions.RewriteFileGroup;
 import org.apache.iceberg.actions.RewriteStrategy;
+import org.apache.iceberg.actions.SortStrategy;
+import org.apache.iceberg.data.GenericRecord;
 import org.apache.iceberg.exceptions.CommitFailedException;
 import org.apache.iceberg.exceptions.ValidationException;
 import org.apache.iceberg.expressions.Expression;
@@ -52,14 +55,16 @@ import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTest
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.iceberg.relocated.com.google.common.collect.Queues;
 import org.apache.iceberg.relocated.com.google.common.collect.Sets;
-import org.apache.iceberg.relocated.com.google.common.collect.Streams;
 import org.apache.iceberg.relocated.com.google.common.math.IntMath;
 import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors;
 import org.apache.iceberg.relocated.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.apache.iceberg.types.Types.StructType;
 import org.apache.iceberg.util.PropertyUtil;
+import org.apache.iceberg.util.StructLikeMap;
 import org.apache.iceberg.util.Tasks;
 import org.apache.spark.sql.SparkSession;
 import org.slf4j.Logger;
@@ -83,12 +88,11 @@ abstract class BaseRewriteDataFilesSparkAction
   private int maxConcurrentFileGroupRewrites;
   private int maxCommits;
   private boolean partialProgressEnabled;
-  private RewriteStrategy strategy;
+  private RewriteStrategy strategy = null;
 
   protected BaseRewriteDataFilesSparkAction(SparkSession spark, Table table) {
     super(spark);
     this.table = table;
-    this.strategy = binPackStrategy();
   }
 
   protected Table table() {
@@ -100,13 +104,36 @@ abstract class BaseRewriteDataFilesSparkAction
    */
   protected abstract BinPackStrategy binPackStrategy();
 
+  /**
+   * The framework specific {@link SortStrategy}
+   */
+  protected abstract SortStrategy sortStrategy();
+
   @Override
   public RewriteDataFiles binPack() {
+    Preconditions.checkArgument(this.strategy == null,
+        "Cannot set strategy to binpack, it has already been set", this.strategy);
     this.strategy = binPackStrategy();
     return this;
   }
 
   @Override
+  public RewriteDataFiles sort(SortOrder sortOrder) {
+    Preconditions.checkArgument(this.strategy == null,
+        "Cannot set strategy to sort, it has already been set to %s", this.strategy);
+    this.strategy = sortStrategy().sortOrder(sortOrder);
+    return this;
+  }
+
+  @Override
+  public RewriteDataFiles sort() {
+    Preconditions.checkArgument(this.strategy == null,
+        "Cannot set strategy to sort, it has already been set to %s", this.strategy);
+    this.strategy = sortStrategy();
+    return this;
+  }
+
+  @Override
   public RewriteDataFiles filter(Expression expression) {
     filter = Expressions.and(filter, expression);
     return this;
@@ -120,6 +147,11 @@ abstract class BaseRewriteDataFilesSparkAction
 
     long startingSnapshotId = table.currentSnapshot().snapshotId();
 
+    // Default to BinPack if no strategy selected
+    if (this.strategy == null) {
+      this.strategy = binPackStrategy();
+    }
+
     validateAndInitOptions();
     strategy = strategy.options(options());
 
@@ -149,10 +181,27 @@ abstract class BaseRewriteDataFilesSparkAction
         .planFiles();
 
     try {
-      Map<StructLike, List<FileScanTask>> filesByPartition = Streams.stream(fileScanTasks)
-          .collect(Collectors.groupingBy(task -> task.file().partition()));
+      StructType partitionType = table.spec().partitionType();
+      StructLikeMap<List<FileScanTask>> filesByPartition = StructLikeMap.create(partitionType);
+      StructLike emptyStruct = GenericRecord.create(partitionType);
+
+      fileScanTasks.forEach(task -> {
+        // If a task uses an incompatible partition spec, the data inside could contain values that
+        // belong to multiple partitions in the current spec. Treating all such files as unpartitioned and
+        // grouping them together helps minimize the number of new files written.
+        StructLike taskPartition = task.file().specId() == table.spec().specId() ?
+            task.file().partition() : emptyStruct;
+
+        List<FileScanTask> files = filesByPartition.get(taskPartition);
+        if (files == null) {
+          files = Lists.newArrayList();
+        }
+
+        files.add(task);
+        filesByPartition.put(taskPartition, files);
+      });
 
-      Map<StructLike, List<List<FileScanTask>>> fileGroupsByPartition = Maps.newHashMap();
+      StructLikeMap<List<List<FileScanTask>>> fileGroupsByPartition = StructLikeMap.create(partitionType);
 
       filesByPartition.forEach((partition, tasks) -> {
         Iterable<FileScanTask> filtered = strategy.selectFilesToRewrite(tasks);
diff --git a/spark/src/test/java/org/apache/iceberg/spark/actions/TestNewRewriteDataFilesAction.java b/spark/src/test/java/org/apache/iceberg/spark/actions/TestNewRewriteDataFilesAction.java
index 8500065..21fa4a1 100644
--- a/spark/src/test/java/org/apache/iceberg/spark/actions/TestNewRewriteDataFilesAction.java
+++ b/spark/src/test/java/org/apache/iceberg/spark/actions/TestNewRewriteDataFilesAction.java
@@ -28,11 +28,15 @@ import java.util.Random;
 import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
+import java.util.stream.Stream;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.DataFile;
 import org.apache.iceberg.FileScanTask;
 import org.apache.iceberg.PartitionSpec;
 import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.StructLike;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.TableProperties;
 import org.apache.iceberg.actions.ActionsProvider;
@@ -41,6 +45,7 @@ import org.apache.iceberg.actions.RewriteDataFiles;
 import org.apache.iceberg.actions.RewriteDataFiles.Result;
 import org.apache.iceberg.actions.RewriteDataFilesCommitManager;
 import org.apache.iceberg.actions.RewriteFileGroup;
+import org.apache.iceberg.actions.SortStrategy;
 import org.apache.iceberg.exceptions.CommitStateUnknownException;
 import org.apache.iceberg.expressions.Expressions;
 import org.apache.iceberg.hadoop.HadoopTables;
@@ -249,8 +254,8 @@ public abstract class TestNewRewriteDataFilesAction extends SparkTestBase {
 
     Result result = basicRewrite(table)
         .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(targetSize))
-        .option(BinPackStrategy.MAX_FILE_SIZE_BYTES, Integer.toString(targetSize + 100))
-        .option(BinPackStrategy.MIN_FILE_SIZE_BYTES, Integer.toString(targetSize - 100))
+        .option(BinPackStrategy.MAX_FILE_SIZE_BYTES, Integer.toString(targetSize + 1000))
+        .option(BinPackStrategy.MIN_FILE_SIZE_BYTES, Integer.toString(targetSize - 1000))
         .execute();
 
     Assert.assertEquals("Action should delete 3 data files", 3, result.rewrittenDataFilesCount());
@@ -274,7 +279,7 @@ public abstract class TestNewRewriteDataFilesAction extends SparkTestBase {
     RewriteDataFiles.Result result =
         basicRewrite(table)
             .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true")
-            .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 100))
+            .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000))
             .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "10")
             .execute();
 
@@ -299,7 +304,7 @@ public abstract class TestNewRewriteDataFilesAction extends SparkTestBase {
     // Perform a rewrite but only allow 2 files to be compacted at a time
     RewriteDataFiles.Result result =
         basicRewrite(table)
-            .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 100))
+            .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000))
             .option(BinPackStrategy.MIN_INPUT_FILES, "1")
             .execute();
 
@@ -324,7 +329,7 @@ public abstract class TestNewRewriteDataFilesAction extends SparkTestBase {
     // Perform a rewrite but only allow 2 files to be compacted at a time
     RewriteDataFiles.Result result =
         basicRewrite(table)
-            .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 100))
+            .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000))
             .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true")
             .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "3")
             .execute();
@@ -350,7 +355,7 @@ public abstract class TestNewRewriteDataFilesAction extends SparkTestBase {
     BaseRewriteDataFilesSparkAction realRewrite =
         (org.apache.iceberg.spark.actions.BaseRewriteDataFilesSparkAction)
             basicRewrite(table)
-                .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 100));
+                .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000));
 
     BaseRewriteDataFilesSparkAction spyRewrite = Mockito.spy(realRewrite);
 
@@ -383,7 +388,7 @@ public abstract class TestNewRewriteDataFilesAction extends SparkTestBase {
     BaseRewriteDataFilesSparkAction realRewrite =
         (org.apache.iceberg.spark.actions.BaseRewriteDataFilesSparkAction)
             basicRewrite(table)
-                .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 100));
+                .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000));
 
     BaseRewriteDataFilesSparkAction spyRewrite = spy(realRewrite);
     RewriteDataFilesCommitManager util = spy(new RewriteDataFilesCommitManager(table));
@@ -420,7 +425,7 @@ public abstract class TestNewRewriteDataFilesAction extends SparkTestBase {
     BaseRewriteDataFilesSparkAction realRewrite =
         (org.apache.iceberg.spark.actions.BaseRewriteDataFilesSparkAction)
             basicRewrite(table)
-                .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 100))
+                .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000))
                 .option(RewriteDataFiles.MAX_CONCURRENT_FILE_GROUP_REWRITES, "3");
 
     BaseRewriteDataFilesSparkAction spyRewrite = Mockito.spy(realRewrite);
@@ -454,7 +459,7 @@ public abstract class TestNewRewriteDataFilesAction extends SparkTestBase {
     BaseRewriteDataFilesSparkAction realRewrite =
         (org.apache.iceberg.spark.actions.BaseRewriteDataFilesSparkAction)
             basicRewrite(table)
-                .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 100))
+                .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000))
                 .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true")
                 .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "3");
 
@@ -492,7 +497,7 @@ public abstract class TestNewRewriteDataFilesAction extends SparkTestBase {
     BaseRewriteDataFilesSparkAction realRewrite =
         (org.apache.iceberg.spark.actions.BaseRewriteDataFilesSparkAction)
             basicRewrite(table)
-                .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 100))
+                .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000))
                 .option(RewriteDataFiles.MAX_CONCURRENT_FILE_GROUP_REWRITES, "3")
                 .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true")
                 .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "3");
@@ -531,7 +536,7 @@ public abstract class TestNewRewriteDataFilesAction extends SparkTestBase {
     BaseRewriteDataFilesSparkAction realRewrite =
         (org.apache.iceberg.spark.actions.BaseRewriteDataFilesSparkAction)
             basicRewrite(table)
-                .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 100))
+                .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000))
                 .option(RewriteDataFiles.MAX_CONCURRENT_FILE_GROUP_REWRITES, "3")
                 .option(RewriteDataFiles.PARTIAL_PROGRESS_ENABLED, "true")
                 .option(RewriteDataFiles.PARTIAL_PROGRESS_MAX_COMMITS, "3");
@@ -591,6 +596,190 @@ public abstract class TestNewRewriteDataFilesAction extends SparkTestBase {
   }
 
   @Test
+  public void testSortMultipleGroups() {
+    Table table = createTable(20);
+    shouldHaveFiles(table, 20);
+    table.replaceSortOrder().asc("c2").commit();
+    shouldHaveLastCommitUnsorted(table, "c2");
+    int fileSize = averageFileSize(table);
+
+    List<Object[]> originalData = currentData();
+
+    // Perform a rewrite but only allow 2 files to be compacted at a time
+    RewriteDataFiles.Result result =
+        basicRewrite(table)
+            .sort()
+            .option(SortStrategy.REWRITE_ALL, "true")
+            .option(RewriteDataFiles.MAX_FILE_GROUP_SIZE_BYTES, Integer.toString(fileSize * 2 + 1000))
+            .execute();
+
+    Assert.assertEquals("Should have 10 fileGroups", result.rewriteResults().size(), 10);
+
+    table.refresh();
+
+    List<Object[]> postRewriteData = currentData();
+    assertEquals("We shouldn't have changed the data", originalData, postRewriteData);
+
+    shouldHaveSnapshots(table, 2);
+    shouldHaveACleanCache(table);
+  }
+
+  @Test
+  public void testSimpleSort() {
+    Table table = createTable(20);
+    shouldHaveFiles(table, 20);
+    table.replaceSortOrder().asc("c2").commit();
+    shouldHaveLastCommitUnsorted(table, "c2");
+
+    List<Object[]> originalData = currentData();
+
+    RewriteDataFiles.Result result =
+        basicRewrite(table)
+            .sort()
+            .option(SortStrategy.MIN_INPUT_FILES, "1")
+            .option(SortStrategy.REWRITE_ALL, "true")
+            .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(averageFileSize(table)))
+            .execute();
+
+    Assert.assertEquals("Should have 1 fileGroups", result.rewriteResults().size(), 1);
+
+    table.refresh();
+
+    List<Object[]> postRewriteData = currentData();
+    assertEquals("We shouldn't have changed the data", originalData, postRewriteData);
+
+    shouldHaveSnapshots(table, 2);
+    shouldHaveACleanCache(table);
+    shouldHaveMultipleFiles(table);
+    shouldHaveLastCommitSorted(table, "c2");
+  }
+
+  @Test
+  public void testSortAfterPartitionChange() {
+    Table table = createTable(20);
+    shouldHaveFiles(table, 20);
+    table.updateSpec().addField(Expressions.bucket("c1", 4)).commit();
+    table.replaceSortOrder().asc("c2").commit();
+    shouldHaveLastCommitUnsorted(table, "c2");
+
+    List<Object[]> originalData = currentData();
+
+    RewriteDataFiles.Result result =
+        basicRewrite(table)
+            .sort()
+            .option(SortStrategy.MIN_INPUT_FILES, "1")
+            .option(SortStrategy.REWRITE_ALL, "true")
+            .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(averageFileSize(table)))
+            .execute();
+
+    Assert.assertEquals("Should have 1 fileGroup because all files were not correctly partitioned",
+        result.rewriteResults().size(), 1);
+
+    table.refresh();
+
+    List<Object[]> postRewriteData = currentData();
+    assertEquals("We shouldn't have changed the data", originalData, postRewriteData);
+
+    shouldHaveSnapshots(table, 2);
+    shouldHaveACleanCache(table);
+    shouldHaveMultipleFiles(table);
+    shouldHaveLastCommitSorted(table, "c2");
+  }
+
+  @Test
+  public void testSortCustomSortOrder() {
+    Table table = createTable(20);
+    shouldHaveLastCommitUnsorted(table, "c2");
+    shouldHaveFiles(table, 20);
+
+    List<Object[]> originalData = currentData();
+
+    RewriteDataFiles.Result result =
+        basicRewrite(table)
+            .sort(SortOrder.builderFor(table.schema()).asc("c2").build())
+            .option(SortStrategy.REWRITE_ALL, "true")
+            .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(averageFileSize(table)))
+            .execute();
+
+    Assert.assertEquals("Should have 1 fileGroups", result.rewriteResults().size(), 1);
+
+    table.refresh();
+
+    List<Object[]> postRewriteData = currentData();
+    assertEquals("We shouldn't have changed the data", originalData, postRewriteData);
+
+    shouldHaveSnapshots(table, 2);
+    shouldHaveACleanCache(table);
+    shouldHaveMultipleFiles(table);
+    shouldHaveLastCommitSorted(table, "c2");
+  }
+
+  @Test
+  public void testSortCustomSortOrderRequiresRepartition() {
+    Table table = createTable(20);
+    shouldHaveLastCommitUnsorted(table, "c3");
+
+    // Add a partition column so this requires repartitioning
+    table.updateSpec().addField("c1").commit();
+    // Add a sort order which our repartitioning needs to ignore
+    table.replaceSortOrder().asc("c2").apply();
+    shouldHaveFiles(table, 20);
+
+    List<Object[]> originalData = currentData();
+
+    RewriteDataFiles.Result result =
+        basicRewrite(table)
+            .sort(SortOrder.builderFor(table.schema()).asc("c3").build())
+            .option(SortStrategy.REWRITE_ALL, "true")
+            .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(averageFileSize(table)))
+            .execute();
+
+    Assert.assertEquals("Should have 1 fileGroups", result.rewriteResults().size(), 1);
+
+    table.refresh();
+
+    List<Object[]> postRewriteData = currentData();
+    assertEquals("We shouldn't have changed the data", originalData, postRewriteData);
+
+    shouldHaveSnapshots(table, 2);
+    shouldHaveACleanCache(table);
+    shouldHaveMultipleFiles(table);
+    shouldHaveLastCommitUnsorted(table, "c2");
+    shouldHaveLastCommitSorted(table, "c3");
+  }
+
+  @Test
+  public void testAutoSortShuffleOutput() {
+    Table table = createTable(20);
+    shouldHaveLastCommitUnsorted(table, "c2");
+    shouldHaveFiles(table, 20);
+
+    List<Object[]> originalData = currentData();
+
+    RewriteDataFiles.Result result =
+        basicRewrite(table)
+            .sort(SortOrder.builderFor(table.schema()).asc("c2").build())
+            .option(SortStrategy.MAX_FILE_SIZE_BYTES, Integer.toString((averageFileSize(table) / 2) + 2))
+            // Divide files in 2
+            .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, Integer.toString(averageFileSize(table) / 2))
+            .option(SortStrategy.MIN_INPUT_FILES, "1")
+            .execute();
+
+    Assert.assertEquals("Should have 1 fileGroups", result.rewriteResults().size(), 1);
+    Assert.assertTrue("Should have written 40+ files", Iterables.size(table.currentSnapshot().addedFiles()) >= 40);
+
+    table.refresh();
+
+    List<Object[]> postRewriteData = currentData();
+    assertEquals("We shouldn't have changed the data", originalData, postRewriteData);
+
+    shouldHaveSnapshots(table, 2);
+    shouldHaveACleanCache(table);
+    shouldHaveMultipleFiles(table);
+    shouldHaveLastCommitSorted(table, "c2");
+  }
+
+  @Test
   public void testCommitStateUnknownException() {
     Table table = createTable(20);
     shouldHaveFiles(table, 20);
@@ -619,6 +808,20 @@ public abstract class TestNewRewriteDataFilesAction extends SparkTestBase {
     shouldHaveSnapshots(table, 2); // Commit actually Succeeded
   }
 
+  @Test
+  public void testInvalidAPIUsage() {
+    Table table = createTable(1);
+
+    AssertHelpers.assertThrows("Should be unable to set Strategy more than once", IllegalArgumentException.class,
+        "Cannot set strategy", () -> actions().rewriteDataFiles(table).binPack().sort());
+
+    AssertHelpers.assertThrows("Should be unable to set Strategy more than once", IllegalArgumentException.class,
+        "Cannot set strategy", () -> actions().rewriteDataFiles(table).sort().binPack());
+
+    AssertHelpers.assertThrows("Should be unable to set Strategy more than once", IllegalArgumentException.class,
+        "Cannot set strategy", () -> actions().rewriteDataFiles(table).sort(SortOrder.unsorted()).binPack());
+  }
+
   protected List<Object[]> currentData() {
     return rowsToJava(spark.read().format("iceberg").load(tableLocation)
         .sort("c1", "c2", "c3")
@@ -629,6 +832,12 @@ public abstract class TestNewRewriteDataFilesAction extends SparkTestBase {
     return Streams.stream(table.newScan().planFiles()).mapToLong(FileScanTask::length).sum();
   }
 
+  protected void shouldHaveMultipleFiles(Table table) {
+    table.refresh();
+    int numFiles = Iterables.size(table.newScan().planFiles());
+    Assert.assertTrue(String.format("Should have multiple files, had %d", numFiles), numFiles > 1);
+  }
+
   protected void shouldHaveFiles(Table table, int numExpected) {
     table.refresh();
     int numFiles = Iterables.size(table.newScan().planFiles());
@@ -666,7 +875,7 @@ public abstract class TestNewRewriteDataFilesAction extends SparkTestBase {
     List<Pair<Pair<T, T>, Pair<T, T>>>
         overlappingFiles = getOverlappingFiles(table, column);
 
-    Assert.assertNotEquals("Found overlapping files", Collections.emptyList(), overlappingFiles);
+    Assert.assertNotEquals("Found no overlapping files", Collections.emptyList(), overlappingFiles);
   }
 
   private <T> List<Pair<Pair<T, T>, Pair<T, T>>> getOverlappingFiles(Table table, String column) {
@@ -674,38 +883,46 @@ public abstract class TestNewRewriteDataFilesAction extends SparkTestBase {
     NestedField field = table.schema().caseInsensitiveFindField(column);
     int columnId = field.fieldId();
     Class<T> javaClass = (Class<T>) field.type().typeId().javaClass();
-    List<Pair<T, T>> columnBounds =
-        Streams.stream(table.currentSnapshot().addedFiles())
-            .map(file -> Pair.of(
-                javaClass.cast(Conversions.fromByteBuffer(field.type(), file.lowerBounds().get(columnId))),
-                javaClass.cast(Conversions.fromByteBuffer(field.type(), file.upperBounds().get(columnId)))))
-            .collect(Collectors.toList());
-
-    Comparator<T> comparator = Comparators.forType(field.type().asPrimitiveType());
-
-    List<Pair<Pair<T, T>, Pair<T, T>>> overlappingFiles = columnBounds.stream()
-        .flatMap(left -> columnBounds.stream().map(right -> Pair.of(left, right)))
-        .filter(filePair -> {
-          Pair<T, T> left = filePair.first();
-          T leftLower = left.first();
-          T leftUpper = left.second();
-          Pair<T, T> right = filePair.second();
-          T rightLower = right.first();
-          T rightUpper = right.second();
-          boolean boundsOverlap =
-              (comparator.compare(leftUpper, rightLower) > 0 && comparator.compare(leftUpper, rightUpper) < 0) ||
-                  (comparator.compare(leftLower, rightLower) > 0 && comparator.compare(leftLower, rightUpper) < 0);
-
-          return (left != right) && boundsOverlap;
-        })
-        .collect(Collectors.toList());
-    return overlappingFiles;
+    Map<StructLike, List<DataFile>> filesByPartition = Streams.stream(table.currentSnapshot().addedFiles())
+        .collect(Collectors.groupingBy(DataFile::partition));
+
+    Stream<Pair<Pair<T, T>, Pair<T, T>>> overlaps =
+        filesByPartition.entrySet().stream().flatMap(entry -> {
+          List<Pair<T, T>> columnBounds =
+              entry.getValue().stream()
+                  .map(file -> Pair.of(
+                      javaClass.cast(Conversions.fromByteBuffer(field.type(), file.lowerBounds().get(columnId))),
+                      javaClass.cast(Conversions.fromByteBuffer(field.type(), file.upperBounds().get(columnId)))))
+                  .collect(Collectors.toList());
+
+          Comparator<T> comparator = Comparators.forType(field.type().asPrimitiveType());
+
+          List<Pair<Pair<T, T>, Pair<T, T>>> overlappingFiles = columnBounds.stream()
+              .flatMap(left -> columnBounds.stream().map(right -> Pair.of(left, right)))
+              .filter(filePair -> {
+                Pair<T, T> left = filePair.first();
+                T lLower = left.first();
+                T lUpper = left.second();
+                Pair<T, T> right = filePair.second();
+                T rLower = right.first();
+                T rUpper = right.second();
+                boolean boundsOverlap =
+                    (comparator.compare(lUpper, rLower) >= 0 && comparator.compare(lUpper, rUpper) <= 0) ||
+                        (comparator.compare(lLower, rLower) >= 0 && comparator.compare(lLower, rUpper) <= 0);
+
+                return (left != right) && boundsOverlap;
+              })
+              .collect(Collectors.toList());
+          return overlappingFiles.stream();
+        });
+
+    return overlaps.collect(Collectors.toList());
   }
 
   /**
    * Create a table with a certain number of files, returns the size of a file
    * @param files number of files to create
-   * @return size of a file
+   * @return the created table
    */
   protected Table createTable(int files) {
     PartitionSpec spec = PartitionSpec.unpartitioned();
@@ -714,7 +931,7 @@ public abstract class TestNewRewriteDataFilesAction extends SparkTestBase {
     table.updateProperties().set(TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES, "1024").commit();
     Assert.assertNull("Table must be empty", table.currentSnapshot());
 
-    writeRecords(files, 2000);
+    writeRecords(files, 40000);
 
     return table;
   }
@@ -743,12 +960,22 @@ public abstract class TestNewRewriteDataFilesAction extends SparkTestBase {
 
   private void writeRecords(int files, int numRecords, int partitions) {
     List<ThreeColumnRecord> records = Lists.newArrayList();
-    List<Integer> data = IntStream.range(0, numRecords).boxed().collect(Collectors.toList());
+    int rowDimension = (int) Math.ceil(Math.sqrt(numRecords));
+    List<Pair<Integer, Integer>> data =
+        IntStream.range(0, rowDimension).boxed().flatMap(x ->
+                IntStream.range(0, rowDimension).boxed().map(y -> Pair.of(x, y)))
+            .collect(Collectors.toList());
     Collections.shuffle(data, new Random(42));
     if (partitions > 0) {
-      data.forEach(i -> records.add(new ThreeColumnRecord(i % partitions, "foo" + i, "bar" + i)));
+      data.forEach(i -> records.add(new ThreeColumnRecord(
+          i.first() % partitions,
+          "foo" + i.first(),
+          "bar" + i.second())));
     } else {
-      data.forEach(i -> records.add(new ThreeColumnRecord(i, "foo" + i, "bar" + i)));
+      data.forEach(i -> records.add(new ThreeColumnRecord(
+          i.first(),
+          "foo" + i.first(),
+          "bar" + i.second())));
     }
     Dataset<Row> df = spark.createDataFrame(records, ThreeColumnRecord.class).repartition(files);
     writeDF(df);
diff --git a/spark/v3.0/build.gradle b/spark/v3.0/build.gradle
index 480f6ad..43dc13d 100644
--- a/spark/v3.0/build.gradle
+++ b/spark/v3.0/build.gradle
@@ -20,6 +20,13 @@
 project(':iceberg-spark:iceberg-spark3') {
   apply plugin: 'scala'
 
+  sourceSets {
+    main {
+      scala.srcDirs = ['src/main/scala', 'src/main/java']
+      java.srcDirs = []
+    }
+  }
+
   dependencies {
     implementation project(path: ':iceberg-bundled-guava', configuration: 'shadow')
     api project(':iceberg-api')
diff --git a/spark/v3.0/spark3/src/main/java/org/apache/iceberg/spark/actions/BaseRewriteDataFilesSpark3Action.java b/spark/v3.0/spark3/src/main/java/org/apache/iceberg/spark/actions/BaseRewriteDataFilesSpark3Action.java
index ac2224f..b1c08e6 100644
--- a/spark/v3.0/spark3/src/main/java/org/apache/iceberg/spark/actions/BaseRewriteDataFilesSpark3Action.java
+++ b/spark/v3.0/spark3/src/main/java/org/apache/iceberg/spark/actions/BaseRewriteDataFilesSpark3Action.java
@@ -22,6 +22,7 @@ package org.apache.iceberg.spark.actions;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.actions.BinPackStrategy;
 import org.apache.iceberg.actions.RewriteDataFiles;
+import org.apache.iceberg.actions.SortStrategy;
 import org.apache.spark.sql.SparkSession;
 
 public class BaseRewriteDataFilesSpark3Action extends BaseRewriteDataFilesSparkAction {
@@ -36,6 +37,11 @@ public class BaseRewriteDataFilesSpark3Action extends BaseRewriteDataFilesSparkA
   }
 
   @Override
+  protected SortStrategy sortStrategy() {
+    return new Spark3SortStrategy(table(), spark());
+  }
+
+  @Override
   protected RewriteDataFiles self() {
     return this;
   }
diff --git a/spark/v3.0/spark3/src/main/java/org/apache/iceberg/spark/actions/Spark3SortStrategy.java b/spark/v3.0/spark3/src/main/java/org/apache/iceberg/spark/actions/Spark3SortStrategy.java
new file mode 100644
index 0000000..2931796
--- /dev/null
+++ b/spark/v3.0/spark3/src/main/java/org/apache/iceberg/spark/actions/Spark3SortStrategy.java
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.actions;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.UUID;
+import org.apache.iceberg.DataFile;
+import org.apache.iceberg.FileScanTask;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.actions.RewriteDataFiles;
+import org.apache.iceberg.actions.RewriteStrategy;
+import org.apache.iceberg.actions.SortStrategy;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet;
+import org.apache.iceberg.spark.FileRewriteCoordinator;
+import org.apache.iceberg.spark.FileScanTaskSetManager;
+import org.apache.iceberg.spark.Spark3Util;
+import org.apache.iceberg.spark.SparkReadOptions;
+import org.apache.iceberg.spark.SparkWriteOptions;
+import org.apache.iceberg.util.PropertyUtil;
+import org.apache.iceberg.util.SortOrderUtil;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
+import org.apache.spark.sql.catalyst.utils.DistributionAndOrderingUtils$;
+import org.apache.spark.sql.connector.iceberg.distributions.Distribution;
+import org.apache.spark.sql.connector.iceberg.distributions.Distributions;
+import org.apache.spark.sql.connector.iceberg.expressions.SortOrder;
+import org.apache.spark.sql.internal.SQLConf;
+
+/**
+ * Spark 3 implementation of {@link SortStrategy}: each file group is read through Spark, sorted
+ * according to {@link #sortOrder()} and written back out as new data files.
+ * <p>
+ * When the group's files (judged by the first task) use a partition spec other than the table's
+ * current spec, the table's partitioning requirement is built into the sort order before sorting.
+ */
+public class Spark3SortStrategy extends SortStrategy {
+
+  /**
+   * The number of shuffle partitions and consequently the number of output files
+   * created by the Spark Sort is based on the size of the input data files used
+   * in this rewrite operation. Due to compression, the disk file sizes may not
+   * accurately represent the size of files in the output. This parameter lets
+   * the user adjust the file size used for estimating actual output data size. A
+   * factor greater than 1.0 would generate more files than we would expect based
+   * on the on-disk file size. A value less than 1.0 would create fewer files than
+   * we would expect due to the on-disk size.
+   * <p>
+   * Defaults to 1.0 and must be positive.
+   */
+  public static final String COMPRESSION_FACTOR = "compression-factor";
+
+  private final Table table;
+  private final SparkSession spark;
+  private final FileScanTaskSetManager manager = FileScanTaskSetManager.get();
+  private final FileRewriteCoordinator rewriteCoordinator = FileRewriteCoordinator.get();
+
+  // Multiplier applied to the on-disk input size when estimating output size (see COMPRESSION_FACTOR).
+  // Starts at the same default options() applies, so an unset value cannot shrink the estimate to zero.
+  private double sizeEstimateMultiple = 1.0;
+
+  public Spark3SortStrategy(Table table, SparkSession spark) {
+    this.table = table;
+    this.spark = spark;
+  }
+
+  @Override
+  public Table table() {
+    return table;
+  }
+
+  /**
+   * Returns the options accepted by {@link SortStrategy} plus {@link #COMPRESSION_FACTOR}.
+   */
+  @Override
+  public Set<String> validOptions() {
+    return ImmutableSet.<String>builder()
+        .addAll(super.validOptions())
+        .add(COMPRESSION_FACTOR)
+        .build();
+  }
+
+  /**
+   * Reads {@link #COMPRESSION_FACTOR} (default 1.0) and passes all options on to
+   * {@link SortStrategy#options(Map)}.
+   *
+   * @throws IllegalArgumentException if the compression factor is not positive
+   */
+  @Override
+  public RewriteStrategy options(Map<String, String> options) {
+    sizeEstimateMultiple = PropertyUtil.propertyAsDouble(options,
+        COMPRESSION_FACTOR,
+        1.0);
+
+    Preconditions.checkArgument(sizeEstimateMultiple > 0,
+        "Invalid compression factor: %s (not positive)", sizeEstimateMultiple);
+
+    return super.options(options);
+  }
+
+  /**
+   * Sorts the rows of one file group and writes them back out as new data files.
+   * <p>
+   * The tasks are staged under a fresh group ID that the Iceberg scan below reads from; the sorted
+   * rows are written with the same group ID, which only writes files without modifying the table,
+   * and the new files are then collected from the {@link FileRewriteCoordinator}.
+   *
+   * @param filesToRewrite scan tasks for the files of one group; must be non-empty
+   * @return the data files written for this group
+   */
+  @Override
+  public Set<DataFile> rewriteFiles(List<FileScanTask> filesToRewrite) {
+    String groupID = UUID.randomUUID().toString();
+    boolean requiresRepartition = !filesToRewrite.get(0).spec().equals(table.spec());
+
+    SortOrder[] ordering;
+    if (requiresRepartition) {
+      // Build in the requirement for Partition Sorting into our sort order
+      ordering = Spark3Util.convert(SortOrderUtil.buildSortOrder(table.schema(), table.spec(), sortOrder()));
+    } else {
+      ordering = Spark3Util.convert(sortOrder());
+    }
+
+    Distribution distribution = Distributions.ordered(ordering);
+
+    try {
+      manager.stageTasks(table, groupID, filesToRewrite);
+
+      // Disable Adaptive Query Execution as this may change the output partitioning of our write
+      SparkSession cloneSession = spark.cloneSession();
+      cloneSession.conf().set(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), false);
+
+      // Reset Shuffle Partitions for our sort; the number of shuffle partitions drives the
+      // number of output files, so it is sized from the estimated output data size
+      long numOutputFiles = numOutputFiles((long) (inputFileSize(filesToRewrite) * sizeEstimateMultiple));
+      cloneSession.conf().set(SQLConf.SHUFFLE_PARTITIONS().key(), Math.max(1, numOutputFiles));
+
+      Dataset<Row> scanDF = cloneSession.read().format("iceberg")
+          .option(SparkReadOptions.FILE_SCAN_TASK_SET_ID, groupID)
+          .load(table.name());
+
+      // Apply the required ordered distribution plus a local (per-partition) sort to the scan;
+      // the sorted partitions are then written out as new data files
+      SQLConf sqlConf = cloneSession.sessionState().conf();
+      LogicalPlan sortPlan = sortPlan(distribution, ordering, scanDF.logicalPlan(), sqlConf);
+      Dataset<Row> sortedDf = new Dataset<>(cloneSession, sortPlan, scanDF.encoder());
+
+      sortedDf.write()
+          .format("iceberg")
+          .option(SparkWriteOptions.REWRITTEN_FILE_SCAN_TASK_SET_ID, groupID)
+          .option(RewriteDataFiles.TARGET_FILE_SIZE_BYTES, writeMaxFileSize())
+          .mode("append") // This will only write files without modifying the table, see SparkWrite.RewriteFiles
+          .save(table.name());
+
+      return rewriteCoordinator.fetchNewDataFiles(table, groupID);
+    } finally {
+      // Always release the staged tasks and rewrite state, even if the read or write failed
+      manager.removeTasks(table, groupID);
+      rewriteCoordinator.clearRewrite(table, groupID);
+    }
+  }
+
+  /**
+   * Returns the Spark session this strategy was created with.
+   */
+  protected SparkSession spark() {
+    return this.spark;
+  }
+
+  /**
+   * Builds the logical plan that applies {@code distribution} and a local sort by {@code ordering}
+   * on top of {@code plan}, using {@code DistributionAndOrderingUtils.prepareQuery}.
+   */
+  protected LogicalPlan sortPlan(Distribution distribution, SortOrder[] ordering, LogicalPlan plan, SQLConf conf) {
+    return DistributionAndOrderingUtils$.MODULE$.prepareQuery(distribution, ordering, plan, conf);
+  }
+}
diff --git a/spark/v3.0/spark3/src/main/scala/org/apache/spark/sql/catalyst/utils/DistributionAndOrderingUtils.scala b/spark/v3.0/spark3/src/main/scala/org/apache/spark/sql/catalyst/utils/DistributionAndOrderingUtils.scala
index 6488779..fb1f758 100644
--- a/spark/v3.0/spark3/src/main/scala/org/apache/spark/sql/catalyst/utils/DistributionAndOrderingUtils.scala
+++ b/spark/v3.0/spark3/src/main/scala/org/apache/spark/sql/catalyst/utils/DistributionAndOrderingUtils.scala
@@ -61,7 +61,7 @@ object DistributionAndOrderingUtils {
 
   def prepareQuery(
       requiredDistribution: Distribution,
-      requiredOrdering: Seq[SortOrder],
+      requiredOrdering: Array[SortOrder],
       query: LogicalPlan,
       conf: SQLConf): LogicalPlan = {
 
@@ -87,8 +87,7 @@ object DistributionAndOrderingUtils {
     }
 
     val ordering = requiredOrdering
-      .map(e => toCatalyst(e, query, resolver))
-      .asInstanceOf[Seq[catalyst.expressions.SortOrder]]
+      .map(e => toCatalyst(e, query, resolver).asInstanceOf[catalyst.expressions.SortOrder])
 
     val queryWithDistributionAndOrdering = if (ordering.nonEmpty) {
       Sort(ordering, global = false, queryWithDistribution)