You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@iceberg.apache.org by ao...@apache.org on 2022/10/13 21:22:15 UTC

[iceberg] branch master updated: Spark 3.3: Split SparkScan and SparkBatch (#5934)

This is an automated email from the ASF dual-hosted git repository.

aokolnychyi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new b196453e1d Spark 3.3: Split SparkScan and SparkBatch (#5934)
b196453e1d is described below

commit b196453e1dac7b3de6cc0378598a249525b194f0
Author: Anton Okolnychyi <ao...@apple.com>
AuthorDate: Thu Oct 13 14:22:05 2022 -0700

    Spark 3.3: Split SparkScan and SparkBatch (#5934)
---
 .../apache/iceberg/spark/source/SparkBatch.java    | 55 +++++++++++++++-------
 .../org/apache/iceberg/spark/source/SparkScan.java | 13 +++--
 2 files changed, 48 insertions(+), 20 deletions(-)

diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java
index 19a1dce3c9..bcfa70bcf2 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java
@@ -19,6 +19,7 @@
 package org.apache.iceberg.spark.source;
 
 import java.util.List;
+import java.util.Objects;
 import org.apache.iceberg.CombinedScanTask;
 import org.apache.iceberg.FileFormat;
 import org.apache.iceberg.Schema;
@@ -31,27 +32,36 @@ import org.apache.iceberg.util.Tasks;
 import org.apache.iceberg.util.ThreadPools;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.broadcast.Broadcast;
-import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.connector.read.Batch;
 import org.apache.spark.sql.connector.read.InputPartition;
 import org.apache.spark.sql.connector.read.PartitionReaderFactory;
 
-abstract class SparkBatch implements Batch {
+class SparkBatch implements Batch {
 
   private final JavaSparkContext sparkContext;
   private final Table table;
   private final SparkReadConf readConf;
+  private final List<CombinedScanTask> taskGroups;
   private final Schema expectedSchema;
   private final boolean caseSensitive;
   private final boolean localityEnabled;
-
-  SparkBatch(SparkSession spark, Table table, SparkReadConf readConf, Schema expectedSchema) {
-    this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext());
+  private final int scanHashCode;
+
+  SparkBatch(
+      JavaSparkContext sparkContext,
+      Table table,
+      SparkReadConf readConf,
+      List<CombinedScanTask> taskGroups,
+      Schema expectedSchema,
+      int scanHashCode) {
+    this.sparkContext = sparkContext;
     this.table = table;
     this.readConf = readConf;
+    this.taskGroups = taskGroups;
     this.expectedSchema = expectedSchema;
     this.caseSensitive = readConf.caseSensitive();
     this.localityEnabled = readConf.localityEnabled();
+    this.scanHashCode = scanHashCode;
   }
 
   @Override
@@ -61,7 +71,7 @@ abstract class SparkBatch implements Batch {
         sparkContext.broadcast(SerializableTableWithSize.copyOf(table));
     String expectedSchemaString = SchemaParser.toJson(expectedSchema);
 
-    InputPartition[] partitions = new InputPartition[tasks().size()];
+    InputPartition[] partitions = new InputPartition[taskGroups.size()];
 
     Tasks.range(partitions.length)
         .stopOnFailure()
@@ -70,7 +80,7 @@ abstract class SparkBatch implements Batch {
             index ->
                 partitions[index] =
                     new SparkInputPartition(
-                        tasks().get(index),
+                        taskGroups.get(index),
                         tableBroadcast,
                         expectedSchemaString,
                         caseSensitive,
@@ -79,12 +89,6 @@ abstract class SparkBatch implements Batch {
     return partitions;
   }
 
-  protected abstract List<CombinedScanTask> tasks();
-
-  protected JavaSparkContext sparkContext() {
-    return sparkContext;
-  }
-
   @Override
   public PartitionReaderFactory createReaderFactory() {
     return new ReaderFactory(batchSize());
@@ -101,7 +105,7 @@ abstract class SparkBatch implements Batch {
   }
 
   private boolean parquetOnly() {
-    return tasks().stream()
+    return taskGroups.stream()
         .allMatch(task -> !task.isDataTask() && onlyFileFormat(task, FileFormat.PARQUET));
   }
 
@@ -115,18 +119,37 @@ abstract class SparkBatch implements Batch {
   }
 
   private boolean orcOnly() {
-    return tasks().stream()
+    return taskGroups.stream()
         .allMatch(task -> !task.isDataTask() && onlyFileFormat(task, FileFormat.ORC));
   }
 
   private boolean orcBatchReadsEnabled() {
     return readConf.orcVectorizationEnabled()
         && // vectorization enabled
-        tasks().stream().noneMatch(TableScanUtil::hasDeletes); // no delete files
+        taskGroups.stream().noneMatch(TableScanUtil::hasDeletes); // no delete files
   }
 
   private boolean onlyFileFormat(CombinedScanTask task, FileFormat fileFormat) {
     return task.files().stream()
         .allMatch(fileScanTask -> fileScanTask.file().format().equals(fileFormat));
   }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+
+    SparkBatch that = (SparkBatch) o;
+    return table.name().equals(that.table.name()) && scanHashCode == that.scanHashCode;
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(table.name(), scanHashCode);
+  }
 }
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java
index 56a9b63a87..7d89fc23bc 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java
@@ -21,6 +21,7 @@ package org.apache.iceberg.spark.source;
 import java.util.Collections;
 import java.util.List;
 import java.util.stream.Collectors;
+import org.apache.iceberg.CombinedScanTask;
 import org.apache.iceberg.ScanTaskGroup;
 import org.apache.iceberg.Schema;
 import org.apache.iceberg.Snapshot;
@@ -37,6 +38,7 @@ import org.apache.iceberg.spark.source.metrics.NumSplits;
 import org.apache.iceberg.spark.source.metrics.TaskNumDeletes;
 import org.apache.iceberg.spark.source.metrics.TaskNumSplits;
 import org.apache.iceberg.util.PropertyUtil;
+import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.metric.CustomMetric;
@@ -54,9 +56,10 @@ import org.apache.spark.sql.vectorized.ColumnarBatch;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-abstract class SparkScan extends SparkBatch implements Scan, SupportsReportStatistics {
+abstract class SparkScan implements Scan, SupportsReportStatistics {
   private static final Logger LOG = LoggerFactory.getLogger(SparkScan.class);
 
+  private final JavaSparkContext sparkContext;
   private final Table table;
   private final SparkReadConf readConf;
   private final boolean caseSensitive;
@@ -73,10 +76,10 @@ abstract class SparkScan extends SparkBatch implements Scan, SupportsReportStati
       SparkReadConf readConf,
       Schema expectedSchema,
       List<Expression> filters) {
-    super(spark, table, readConf, expectedSchema);
 
     SparkSchemaUtil.validateMetadataColumnReferences(table.schema(), expectedSchema);
 
+    this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext());
     this.table = table;
     this.readConf = readConf;
     this.caseSensitive = readConf.caseSensitive();
@@ -101,15 +104,17 @@ abstract class SparkScan extends SparkBatch implements Scan, SupportsReportStati
     return filterExpressions;
   }
 
+  protected abstract List<CombinedScanTask> tasks();
+
   @Override
   public Batch toBatch() {
-    return this;
+    return new SparkBatch(sparkContext, table, readConf, tasks(), expectedSchema, hashCode());
   }
 
   @Override
   public MicroBatchStream toMicroBatchStream(String checkpointLocation) {
     return new SparkMicroBatchStream(
-        sparkContext(), table, readConf, expectedSchema, checkpointLocation);
+        sparkContext, table, readConf, expectedSchema, checkpointLocation);
   }
 
   @Override