Posted to commits@iceberg.apache.org by bl...@apache.org on 2021/10/22 19:45:58 UTC
[iceberg] branch master updated: Spark: Initial support for 3.2 (#3335)
This is an automated email from the ASF dual-hosted git repository.
blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new f5a7537 Spark: Initial support for 3.2 (#3335)
f5a7537 is described below
commit f5a753791f4dc6aca78569a14f731feda9edf462
Author: Anton Okolnychyi <ao...@apple.com>
AuthorDate: Fri Oct 22 12:45:43 2021 -0700
Spark: Initial support for 3.2 (#3335)
---
.github/workflows/spark-ci.yml | 27 +
.gitignore | 2 +
build.gradle | 10 -
gradle.properties | 4 +-
jmh.gradle | 4 +
settings.gradle | 12 +
spark/build.gradle | 4 +
spark/v2.4/build.gradle | 16 +
spark/v3.0/build.gradle | 27 +-
spark/{v3.0 => v3.2}/build.gradle | 55 +-
.../IcebergSqlExtensions.g4 | 293 ++++
.../extensions/IcebergSparkSessionExtensions.scala | 41 +
.../analysis/ProcedureArgumentCoercion.scala | 56 +
.../sql/catalyst/analysis/ResolveProcedures.scala | 190 +++
.../IcebergSparkSqlExtensionsParser.scala | 295 ++++
.../IcebergSqlExtensionsAstBuilder.scala | 286 ++++
.../plans/logical/AddPartitionField.scala} | 18 +-
.../spark/sql/catalyst/plans/logical/Call.scala} | 19 +-
.../plans/logical/DropIdentifierFields.scala} | 19 +-
.../plans/logical/DropPartitionField.scala} | 18 +-
.../plans/logical/ReplacePartitionField.scala} | 23 +-
.../plans/logical/SetIdentifierFields.scala} | 20 +-
.../logical/SetWriteDistributionAndOrdering.scala | 44 +
.../sql/catalyst/plans/logical/statements.scala} | 29 +-
.../datasources/v2/AddPartitionFieldExec.scala | 56 +
.../sql/execution/datasources/v2/CallExec.scala} | 24 +-
.../datasources/v2/DropIdentifierFieldsExec.scala | 65 +
.../datasources/v2/DropPartitionFieldExec.scala | 67 +
.../v2/ExtendedDataSourceV2Strategy.scala | 93 ++
.../datasources/v2/ReplacePartitionFieldExec.scala | 72 +
.../datasources/v2/SetIdentifierFieldsExec.scala | 52 +
.../v2/SetWriteDistributionAndOrderingExec.scala | 78 +
.../apache/iceberg/spark/extensions/Employee.java | 68 +
.../spark/extensions/SparkExtensionsTestBase.java | 61 +
.../SparkRowLevelOperationsTestBase.java | 206 +++
.../spark/extensions/TestAddFilesProcedure.java | 738 ++++++++++
.../extensions/TestAlterTablePartitionFields.java | 415 ++++++
.../spark/extensions/TestAlterTableSchema.java | 142 ++
.../spark/extensions/TestCallStatementParser.java | 169 +++
.../TestCherrypickSnapshotProcedure.java | 180 +++
.../spark/extensions/TestCopyOnWriteDelete.java} | 24 +-
.../spark/extensions/TestCopyOnWriteMerge.java} | 24 +-
.../spark/extensions/TestCopyOnWriteUpdate.java} | 24 +-
.../iceberg/spark/extensions/TestDelete.java | 744 ++++++++++
.../extensions/TestExpireSnapshotsProcedure.java | 227 +++
.../spark/extensions/TestIcebergExpressions.java | 70 +
.../apache/iceberg/spark/extensions/TestMerge.java | 1492 ++++++++++++++++++++
.../extensions/TestMigrateTableProcedure.java | 151 ++
.../extensions/TestRemoveOrphanFilesProcedure.java | 235 +++
.../extensions/TestRewriteManifestsProcedure.java | 174 +++
.../TestRollbackToSnapshotProcedure.java | 260 ++++
.../TestRollbackToTimestampProcedure.java | 268 ++++
.../TestSetCurrentSnapshotProcedure.java | 221 +++
.../TestSetWriteDistributionAndOrdering.java | 282 ++++
.../extensions/TestSnapshotTableProcedure.java | 191 +++
.../iceberg/spark/extensions/TestUpdate.java | 899 ++++++++++++
spark/v3.2/spark-runtime/LICENSE | 682 +++++++++
spark/v3.2/spark-runtime/NOTICE | 508 +++++++
.../java/org/apache/iceberg/spark/SmokeTest.java | 160 +++
.../spark/benchmark/.gitkeep} | 11 -
.../apache/iceberg/spark/SparkBenchmarkUtil.java | 58 +
.../SparkParquetReadersFlatDataBenchmark.java | 216 +++
.../SparkParquetReadersNestedDataBenchmark.java | 216 +++
.../SparkParquetWritersFlatDataBenchmark.java | 124 ++
.../SparkParquetWritersNestedDataBenchmark.java | 123 ++
.../org/apache/iceberg/spark/source/Action.java} | 12 +-
.../spark/source/IcebergSourceBenchmark.java | 190 +++
.../source/IcebergSourceFlatDataBenchmark.java | 59 +
.../source/IcebergSourceNestedDataBenchmark.java | 58 +
.../IcebergSourceNestedListDataBenchmark.java | 56 +
.../iceberg/spark/source/WritersBenchmark.java | 353 +++++
.../spark/source/avro/AvroWritersBenchmark.java} | 26 +-
.../IcebergSourceFlatAvroDataReadBenchmark.java | 132 ++
.../IcebergSourceNestedAvroDataReadBenchmark.java | 132 ++
.../orc/IcebergSourceFlatORCDataBenchmark.java | 67 +
.../orc/IcebergSourceFlatORCDataReadBenchmark.java | 184 +++
...ebergSourceNestedListORCDataWriteBenchmark.java | 99 ++
.../IcebergSourceNestedORCDataReadBenchmark.java | 161 +++
...cebergSourceFlatParquetDataFilterBenchmark.java | 122 ++
.../IcebergSourceFlatParquetDataReadBenchmark.java | 153 ++
...IcebergSourceFlatParquetDataWriteBenchmark.java | 90 ++
...gSourceNestedListParquetDataWriteBenchmark.java | 89 ++
...bergSourceNestedParquetDataFilterBenchmark.java | 121 ++
...cebergSourceNestedParquetDataReadBenchmark.java | 154 ++
...ebergSourceNestedParquetDataWriteBenchmark.java | 90 ++
.../source/parquet/ParquetWritersBenchmark.java} | 26 +-
...dDictionaryEncodedFlatParquetDataBenchmark.java | 139 ++
.../VectorizedReadFlatParquetDataBenchmark.java | 304 ++++
.../java/org/apache/iceberg/actions/Actions.java | 210 +++
.../org/apache/iceberg/actions/CreateAction.java | 46 +
.../iceberg/actions/ExpireSnapshotsAction.java | 135 ++
.../actions/ExpireSnapshotsActionResult.java | 54 +
.../apache/iceberg/actions/ManifestFileBean.java | 144 ++
.../iceberg/actions/RemoveOrphanFilesAction.java | 90 ++
.../iceberg/actions/RewriteDataFilesAction.java | 70 +
.../iceberg/actions/RewriteManifestsAction.java | 86 ++
.../actions/RewriteManifestsActionResult.java | 57 +
.../apache/iceberg/actions/SnapshotAction.java} | 15 +-
.../iceberg/actions/Spark3MigrateAction.java | 62 +
.../iceberg/actions/Spark3SnapshotAction.java | 70 +
.../org/apache/iceberg/actions/SparkActions.java | 66 +
.../java/org/apache/iceberg/spark/BaseCatalog.java | 48 +
.../iceberg/spark/FileRewriteCoordinator.java | 96 ++
.../iceberg/spark/FileScanTaskSetManager.java | 78 +
.../org/apache/iceberg/spark/IcebergSpark.java} | 24 +-
.../org/apache/iceberg/spark/JobGroupInfo.java} | 33 +-
.../org/apache/iceberg/spark/JobGroupUtils.java | 46 +
.../java/org/apache/iceberg/spark/OrderField.java | 102 ++
.../org/apache/iceberg/spark/PathIdentifier.java | 57 +
.../iceberg/spark/PruneColumnsWithReordering.java | 258 ++++
.../spark/PruneColumnsWithoutReordering.java | 226 +++
.../apache/iceberg/spark/RollbackStagedTable.java | 137 ++
.../org/apache/iceberg/spark/SortOrderToSpark.java | 62 +
.../java/org/apache/iceberg/spark/Spark3Util.java | 921 ++++++++++++
.../org/apache/iceberg/spark/SparkCatalog.java | 470 ++++++
.../org/apache/iceberg/spark/SparkConfParser.java | 186 +++
.../org/apache/iceberg/spark/SparkDataFile.java | 203 +++
.../apache/iceberg/spark/SparkExceptionUtil.java | 66 +
.../org/apache/iceberg/spark/SparkFilters.java | 218 +++
.../iceberg/spark/SparkFixupTimestampType.java | 57 +
.../org/apache/iceberg/spark/SparkFixupTypes.java | 64 +
.../org/apache/iceberg/spark/SparkReadConf.java | 177 +++
.../org/apache/iceberg/spark/SparkReadOptions.java | 68 +
.../apache/iceberg/spark/SparkSQLProperties.java | 46 +
.../org/apache/iceberg/spark/SparkSchemaUtil.java | 302 ++++
.../apache/iceberg/spark/SparkSessionCatalog.java | 299 ++++
.../org/apache/iceberg/spark/SparkStructLike.java | 55 +
.../org/apache/iceberg/spark/SparkTableUtil.java | 695 +++++++++
.../org/apache/iceberg/spark/SparkTypeToType.java | 163 +++
.../org/apache/iceberg/spark/SparkTypeVisitor.java | 83 ++
.../java/org/apache/iceberg/spark/SparkUtil.java | 182 +++
.../apache/iceberg/spark/SparkValueConverter.java | 120 ++
.../org/apache/iceberg/spark/SparkWriteConf.java | 140 ++
.../apache/iceberg/spark/SparkWriteOptions.java | 56 +
.../org/apache/iceberg/spark/TypeToSparkType.java | 125 ++
.../actions/BaseDeleteOrphanFilesSparkAction.java | 264 ++++
.../BaseDeleteReachableFilesSparkAction.java | 204 +++
.../actions/BaseExpireSnapshotsSparkAction.java | 276 ++++
.../spark/actions/BaseMigrateTableSparkAction.java | 214 +++
.../actions/BaseRewriteDataFilesSpark3Action.java} | 33 +-
.../actions/BaseRewriteDataFilesSparkAction.java | 436 ++++++
.../actions/BaseRewriteManifestsSparkAction.java | 364 +++++
.../actions/BaseSnapshotTableSparkAction.java | 212 +++
.../actions/BaseSnapshotUpdateSparkAction.java} | 31 +-
.../iceberg/spark/actions/BaseSparkAction.java | 168 +++
.../iceberg/spark/actions/BaseSparkActions.java | 61 +
.../actions/BaseTableCreationSparkAction.java | 168 +++
.../spark/actions/Spark3BinPackStrategy.java | 84 ++
.../iceberg/spark/actions/Spark3SortStrategy.java | 158 +++
.../apache/iceberg/spark/actions/SparkActions.java | 72 +
.../spark/data/AvroWithSparkSchemaVisitor.java | 74 +
.../spark/data/ParquetWithSparkSchemaVisitor.java | 208 +++
.../apache/iceberg/spark/data/SparkAvroReader.java | 167 +++
.../apache/iceberg/spark/data/SparkAvroWriter.java | 156 ++
.../apache/iceberg/spark/data/SparkOrcReader.java | 123 ++
.../iceberg/spark/data/SparkOrcValueReaders.java | 232 +++
.../iceberg/spark/data/SparkOrcValueWriters.java | 201 +++
.../apache/iceberg/spark/data/SparkOrcWriter.java | 222 +++
.../iceberg/spark/data/SparkParquetReaders.java | 756 ++++++++++
.../iceberg/spark/data/SparkParquetWriters.java | 423 ++++++
.../iceberg/spark/data/SparkValueReaders.java | 288 ++++
.../iceberg/spark/data/SparkValueWriters.java | 252 ++++
.../vectorized/ArrowVectorAccessorFactory.java | 110 ++
.../data/vectorized/ArrowVectorAccessors.java} | 25 +-
.../spark/data/vectorized/ColumnarBatchReader.java | 64 +
.../data/vectorized/ConstantColumnVector.java | 124 ++
.../data/vectorized/IcebergArrowColumnVector.java | 157 ++
.../data/vectorized/RowPositionColumnVector.java | 122 ++
.../data/vectorized/VectorizedSparkOrcReaders.java | 427 ++++++
.../vectorized/VectorizedSparkParquetReaders.java | 52 +
.../spark/procedures/AddFilesProcedure.java | 246 ++++
.../iceberg/spark/procedures/BaseProcedure.java | 154 ++
.../procedures/CherrypickSnapshotProcedure.java | 97 ++
.../spark/procedures/ExpireSnapshotsProcedure.java | 137 ++
.../spark/procedures/MigrateTableProcedure.java | 94 ++
.../procedures/RemoveOrphanFilesProcedure.java | 142 ++
.../procedures/RewriteManifestsProcedure.java | 107 ++
.../procedures/RollbackToSnapshotProcedure.java | 96 ++
.../procedures/RollbackToTimestampProcedure.java | 100 ++
.../procedures/SetCurrentSnapshotProcedure.java | 97 ++
.../spark/procedures/SnapshotTableProcedure.java | 109 ++
.../iceberg/spark/procedures/SparkProcedures.java | 61 +
.../iceberg/spark/source/BaseDataReader.java | 199 +++
.../iceberg/spark/source/BatchDataReader.java | 117 ++
.../spark/source/EqualityDeleteRowReader.java | 54 +
.../apache/iceberg/spark/source/IcebergSource.java | 148 ++
.../iceberg/spark/source/InternalRowWrapper.java | 96 ++
.../apache/iceberg/spark/source/RowDataReader.java | 199 +++
.../iceberg/spark/source/RowDataRewriter.java | 153 ++
.../iceberg/spark/source/SparkAppenderFactory.java | 273 ++++
.../iceberg/spark/source/SparkBatchQueryScan.java | 166 +++
.../iceberg/spark/source/SparkBatchScan.java | 330 +++++
.../spark/source/SparkFileWriterFactory.java | 245 ++++
.../iceberg/spark/source/SparkFilesScan.java | 106 ++
.../spark/source/SparkFilesScanBuilder.java | 49 +
.../iceberg/spark/source/SparkMergeBuilder.java | 110 ++
.../iceberg/spark/source/SparkMergeScan.java | 179 +++
.../spark/source/SparkMicroBatchStream.java | 264 ++++
.../spark/source/SparkPartitionedFanoutWriter.java | 51 +
.../spark/source/SparkPartitionedWriter.java | 51 +
.../iceberg/spark/source/SparkRewriteBuilder.java | 69 +
.../iceberg/spark/source/SparkScanBuilder.java | 174 +++
.../apache/iceberg/spark/source/SparkTable.java | 286 ++++
.../apache/iceberg/spark/source/SparkWrite.java | 653 +++++++++
.../iceberg/spark/source/SparkWriteBuilder.java | 153 ++
.../iceberg/spark/source/StagedSparkTable.java} | 27 +-
.../org/apache/iceberg/spark/source/Stats.java} | 28 +-
.../iceberg/spark/source/StreamingOffset.java | 154 ++
.../iceberg/spark/source/StructInternalRow.java | 336 +++++
.../analysis/NoSuchProcedureException.java} | 16 +-
.../iceberg/catalog/ExtendedSupportsDelete.java | 43 +
.../sql/connector/iceberg/catalog/Procedure.java | 60 +
.../iceberg/catalog/ProcedureCatalog.java | 41 +
.../iceberg/catalog/ProcedureParameter.java | 65 +
.../iceberg/catalog/ProcedureParameterImpl.java | 77 +
.../connector/iceberg/catalog/SupportsMerge.java | 41 +
.../distributions/ClusteredDistribution.java} | 22 +-
.../iceberg/distributions/Distribution.java} | 17 +-
.../iceberg/distributions/Distributions.java | 61 +
.../distributions/OrderedDistribution.java} | 22 +-
.../distributions/UnspecifiedDistribution.java} | 17 +-
.../impl/ClusterDistributionImpl.java} | 22 +-
.../impl/OrderedDistributionImpl.java} | 22 +-
.../impl/UnspecifiedDistributionImpl.java} | 12 +-
.../iceberg/expressions/NullOrdering.java} | 29 +-
.../iceberg/expressions/SortDirection.java} | 29 +-
.../connector/iceberg/expressions/SortOrder.java} | 31 +-
.../iceberg/read/SupportsFileFilter.java} | 21 +-
.../sql/connector/iceberg/write/MergeBuilder.java} | 27 +-
...org.apache.spark.sql.sources.DataSourceRegister | 20 +
.../expressions/TransformExpressions.scala | 159 +++
.../utils/DistributionAndOrderingUtils.scala | 188 +++
.../spark/sql/catalyst/utils/PlanUtils.scala} | 27 +-
.../test/java/org/apache/iceberg/KryoHelpers.java | 52 +
.../java/org/apache/iceberg/TaskCheckHelper.java | 93 ++
.../apache/iceberg/TestDataFileSerialization.java | 170 +++
.../apache/iceberg/TestFileIOSerialization.java | 113 ++
.../iceberg/TestManifestFileSerialization.java | 194 +++
.../apache/iceberg/TestScanTaskSerialization.java | 145 ++
.../org/apache/iceberg/TestTableSerialization.java | 102 ++
.../apache/iceberg/actions/TestCreateActions.java | 651 +++++++++
.../actions/TestDeleteReachableFilesAction.java | 349 +++++
.../iceberg/actions/TestExpireSnapshotsAction.java | 1062 ++++++++++++++
.../actions/TestRemoveOrphanFilesAction.java | 626 ++++++++
.../actions/TestRemoveOrphanFilesAction3.java | 169 +++
.../actions/TestRewriteDataFilesAction.java | 479 +++++++
.../actions/TestRewriteManifestsAction.java | 421 ++++++
.../apache/iceberg/spark/SparkCatalogTestBase.java | 109 ++
.../org/apache/iceberg/spark/SparkTestBase.java | 206 +++
.../iceberg/spark/TestFileRewriteCoordinator.java | 252 ++++
.../org/apache/iceberg/spark/TestSpark3Util.java | 108 ++
.../org/apache/iceberg/spark/TestSparkFilters.java | 67 +
.../apache/iceberg/spark/TestSparkSchemaUtil.java | 55 +
.../actions/TestNewRewriteDataFilesAction.java | 1017 +++++++++++++
.../apache/iceberg/spark/data/AvroDataTest.java | 240 ++++
.../apache/iceberg/spark/data/GenericsHelpers.java | 311 ++++
.../org/apache/iceberg/spark/data/RandomData.java | 363 +++++
.../org/apache/iceberg/spark/data/TestHelpers.java | 696 +++++++++
.../apache/iceberg/spark/data/TestOrcWrite.java | 61 +
.../iceberg/spark/data/TestParquetAvroReader.java | 232 +++
.../iceberg/spark/data/TestParquetAvroWriter.java | 107 ++
.../iceberg/spark/data/TestSparkAvroEnums.java | 95 ++
.../iceberg/spark/data/TestSparkAvroReader.java | 66 +
.../iceberg/spark/data/TestSparkDateTimes.java | 77 +
.../data/TestSparkOrcReadMetadataColumns.java | 209 +++
.../iceberg/spark/data/TestSparkOrcReader.java | 107 ++
.../data/TestSparkParquetReadMetadataColumns.java | 232 +++
.../iceberg/spark/data/TestSparkParquetReader.java | 203 +++
.../iceberg/spark/data/TestSparkParquetWriter.java | 100 ++
.../spark/data/TestSparkRecordOrcReaderWriter.java | 148 ++
...estParquetDictionaryEncodedVectorizedReads.java | 92 ++
...naryFallbackToPlainEncodingVectorizedReads.java | 74 +
.../vectorized/TestParquetVectorizedReads.java | 309 ++++
.../apache/iceberg/spark/source/LogMessage.java | 120 ++
.../apache/iceberg/spark/source/SimpleRecord.java | 80 ++
.../iceberg/spark/source/SparkTestTable.java | 60 +
.../apache/iceberg/spark/source/TestAvroScan.java | 115 ++
.../iceberg/spark/source/TestDataFrameWrites.java | 385 +++++
.../spark/source/TestDataSourceOptions.java | 406 ++++++
.../iceberg/spark/source/TestFilteredScan.java | 567 ++++++++
.../spark/source/TestForwardCompatibility.java | 217 +++
.../iceberg/spark/source/TestIcebergSource.java} | 29 +-
.../source/TestIcebergSourceHadoopTables.java | 68 +
.../spark/source/TestIcebergSourceHiveTables.java | 75 +
.../spark/source/TestIcebergSourceTablesBase.java | 1183 ++++++++++++++++
.../iceberg/spark/source/TestIcebergSpark.java | 70 +
.../spark/source/TestIdentityPartitionData.java | 185 +++
.../spark/source/TestInternalRowWrapper.java | 78 +
.../TestMetadataTablesWithPartitionEvolution.java | 343 +++++
.../iceberg/spark/source/TestParquetScan.java | 139 ++
.../iceberg/spark/source/TestPartitionPruning.java | 415 ++++++
.../iceberg/spark/source/TestPartitionValues.java | 434 ++++++
.../iceberg/spark/source/TestPathIdentifier.java | 88 ++
.../iceberg/spark/source/TestReadProjection.java | 578 ++++++++
.../spark/source/TestSnapshotSelection.java | 238 ++++
.../spark/source/TestSparkAppenderFactory.java | 70 +
.../spark/source/TestSparkBaseDataReader.java | 278 ++++
.../iceberg/spark/source/TestSparkCatalog.java | 43 +
.../source/TestSparkCatalogHadoopOverrides.java | 133 ++
.../iceberg/spark/source/TestSparkDataFile.java | 216 +++
.../iceberg/spark/source/TestSparkDataWrite.java | 600 ++++++++
.../spark/source/TestSparkFileWriterFactory.java | 73 +
.../iceberg/spark/source/TestSparkFilesScan.java | 121 ++
.../spark/source/TestSparkMergingMetrics.java | 68 +
.../spark/source/TestSparkMetadataColumns.java | 194 +++
.../spark/source/TestSparkPartitioningWriters.java | 73 +
.../source/TestSparkPositionDeltaWriters.java | 73 +
.../spark/source/TestSparkReadProjection.java | 243 ++++
.../spark/source/TestSparkReaderDeletes.java | 218 +++
.../spark/source/TestSparkRollingFileWriters.java | 59 +
.../iceberg/spark/source/TestSparkTable.java | 61 +
.../spark/source/TestSparkWriterMetrics.java | 59 +
.../iceberg/spark/source/TestStreamingOffset.java | 54 +
.../spark/source/TestStructuredStreaming.java | 305 ++++
.../spark/source/TestStructuredStreamingRead3.java | 431 ++++++
.../apache/iceberg/spark/source/TestTables.java | 206 +++
.../spark/source/TestTimestampWithoutZone.java | 290 ++++
.../spark/source/TestWriteMetricsConfig.java | 306 ++++
.../iceberg/spark/source/ThreeColumnRecord.java | 89 ++
.../apache/iceberg/spark/sql/TestAlterTable.java | 228 +++
.../apache/iceberg/spark/sql/TestCreateTable.java | 282 ++++
.../iceberg/spark/sql/TestCreateTableAsSelect.java | 314 ++++
.../apache/iceberg/spark/sql/TestDeleteFrom.java | 59 +
.../apache/iceberg/spark/sql/TestNamespaceSQL.java | 221 +++
.../iceberg/spark/sql/TestPartitionedWrites.java | 179 +++
.../apache/iceberg/spark/sql/TestRefreshTable.java | 76 +
.../org/apache/iceberg/spark/sql/TestSelect.java | 123 ++
.../spark/sql/TestTimestampWithoutZone.java | 196 +++
.../iceberg/spark/sql/TestUnpartitionedWrites.java | 155 ++
329 files changed, 57098 insertions(+), 307 deletions(-)
diff --git a/.github/workflows/spark-ci.yml b/.github/workflows/spark-ci.yml
index 59ce195..28ee236 100644
--- a/.github/workflows/spark-ci.yml
+++ b/.github/workflows/spark-ci.yml
@@ -98,3 +98,30 @@ jobs:
name: test logs
path: |
**/build/testlogs
+
+ spark-32-tests:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ jvm: [8, 11]
+ spark: ['3.2']
+ env:
+ SPARK_LOCAL_IP: localhost
+ steps:
+ - uses: actions/checkout@v2
+ - uses: actions/setup-java@v1
+ with:
+ java-version: ${{ matrix.jvm }}
+ - uses: actions/cache@v2
+ with:
+ path: ~/.gradle/caches
+ key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }}
+ restore-keys: ${{ runner.os }}-gradle
+ - run: echo -e "$(ip addr show eth0 | grep "inet\b" | awk '{print $2}' | cut -d/ -f1)\t$(hostname -f) $(hostname -s)" | sudo tee -a /etc/hosts
+ - run: ./gradlew -DsparkVersions=${{ matrix.spark }} -DhiveVersions= -DflinkVersions= :iceberg-spark:iceberg-spark-3.2:check :iceberg-spark:iceberg-spark-3.2-extensions:check :iceberg-spark:iceberg-spark-3.2-runtime:check -Pquick=true -x javadoc
+ - uses: actions/upload-artifact@v2
+ if: failure()
+ with:
+ name: test logs
+ path: |
+ **/build/testlogs
diff --git a/.gitignore b/.gitignore
index d8e6dd5..8cbb17a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -30,6 +30,8 @@ spark/v2.4/spark2/benchmark/*
!spark/v2.4/spark2/benchmark/.gitkeep
spark/v3.0/spark3/benchmark/*
!spark/v3.0/spark3/benchmark/.gitkeep
+spark/v3.2/spark/benchmark/*
+!spark/v3.2/spark/benchmark/.gitkeep
__pycache__/
*.py[cod]
diff --git a/build.gradle b/build.gradle
index e537d8c..c1f74a2 100644
--- a/build.gradle
+++ b/build.gradle
@@ -64,10 +64,6 @@ allprojects {
mavenCentral()
mavenLocal()
}
- project.ext {
- Spark30Version = '3.0.3'
- Spark31Version = '3.1.1'
- }
}
subprojects {
@@ -92,12 +88,6 @@ subprojects {
exclude group: 'com.sun.jersey'
exclude group: 'com.sun.jersey.contribs'
exclude group: 'org.pentaho', module: 'pentaho-aggdesigner-algorithm'
-
- resolutionStrategy {
- force 'com.fasterxml.jackson.module:jackson-module-scala_2.11:2.11.4'
- force 'com.fasterxml.jackson.module:jackson-module-scala_2.12:2.11.4'
- force 'com.fasterxml.jackson.module:jackson-module-paranamer:2.11.4'
- }
}
testArtifacts
diff --git a/gradle.properties b/gradle.properties
index 43fdba6..d7ed4a2 100644
--- a/gradle.properties
+++ b/gradle.properties
@@ -19,7 +19,7 @@ systemProp.defaultFlinkVersions=1.13
systemProp.knownFlinkVersions=1.13
systemProp.defaultHiveVersions=2,3
systemProp.knownHiveVersions=2,3
-systemProp.defaultSparkVersions=2.4,3.0
-systemProp.knownSparkVersions=2.4,3.0
+systemProp.defaultSparkVersions=2.4,3.0,3.2
+systemProp.knownSparkVersions=2.4,3.0,3.2
org.gradle.parallel=true
org.gradle.jvmargs=-Xmx768m
diff --git a/jmh.gradle b/jmh.gradle
index 24a7878..50f73b0 100644
--- a/jmh.gradle
+++ b/jmh.gradle
@@ -32,6 +32,10 @@ if (sparkVersions.contains("3.0")) {
jmhProjects.add(project(":iceberg-spark:iceberg-spark3"))
}
+if (sparkVersions.contains("3.2")) {
+ jmhProjects.add(project(":iceberg-spark:iceberg-spark-3.2"))
+}
+
configure(jmhProjects) {
apply plugin: 'me.champeau.gradle.jmh'
diff --git a/settings.gradle b/settings.gradle
index ec05cf6..b70db46 100644
--- a/settings.gradle
+++ b/settings.gradle
@@ -91,6 +91,18 @@ if (sparkVersions.contains("3.0")) {
project(':iceberg-spark:spark3-runtime').name = 'iceberg-spark3-runtime'
}
+if (sparkVersions.contains("3.2")) {
+ include ':iceberg-spark:spark-3.2'
+ include ':iceberg-spark:spark-3.2-extensions'
+ include ':iceberg-spark:spark-3.2-runtime'
+ project(':iceberg-spark:spark-3.2').projectDir = file('spark/v3.2/spark')
+ project(':iceberg-spark:spark-3.2').name = 'iceberg-spark-3.2'
+ project(':iceberg-spark:spark-3.2-extensions').projectDir = file('spark/v3.2/spark-extensions')
+ project(':iceberg-spark:spark-3.2-extensions').name = 'iceberg-spark-3.2-extensions'
+ project(':iceberg-spark:spark-3.2-runtime').projectDir = file('spark/v3.2/spark-runtime')
+ project(':iceberg-spark:spark-3.2-runtime').name = 'iceberg-spark-3.2-runtime'
+}
+
// hive 3 depends on hive 2, so always add hive 2 if hive3 is enabled
if (hiveVersions.contains("2") || hiveVersions.contains("3")) {
include 'mr'
diff --git a/spark/build.gradle b/spark/build.gradle
index 30ea7fe..9ba32f2 100644
--- a/spark/build.gradle
+++ b/spark/build.gradle
@@ -27,3 +27,7 @@ if (jdkVersion == '8' && sparkVersions.contains("2.4")) {
if (sparkVersions.contains("3.0")) {
apply from: file("$projectDir/v3.0/build.gradle")
}
+
+if (sparkVersions.contains("3.2")) {
+ apply from: file("$projectDir/v3.2/build.gradle")
+}
\ No newline at end of file
diff --git a/spark/v2.4/build.gradle b/spark/v2.4/build.gradle
index 4fbbf73..0789f12 100644
--- a/spark/v2.4/build.gradle
+++ b/spark/v2.4/build.gradle
@@ -21,6 +21,22 @@ if (jdkVersion != '8') {
throw new GradleException("Spark 2.4 must be built with Java 8")
}
+def sparkProjects = [
+ project(':iceberg-spark:iceberg-spark2'),
+ project(':iceberg-spark:iceberg-spark-runtime')
+]
+
+configure(sparkProjects) {
+ configurations {
+ all {
+ resolutionStrategy {
+ force 'com.fasterxml.jackson.module:jackson-module-scala_2.11:2.11.4'
+ force 'com.fasterxml.jackson.module:jackson-module-paranamer:2.11.4'
+ }
+ }
+ }
+}
+
project(':iceberg-spark:iceberg-spark2') {
configurations.all {
resolutionStrategy {
diff --git a/spark/v3.0/build.gradle b/spark/v3.0/build.gradle
index ec0b7c4..c337036 100644
--- a/spark/v3.0/build.gradle
+++ b/spark/v3.0/build.gradle
@@ -17,6 +17,27 @@
* under the License.
*/
+def sparkProjects = [
+ project(':iceberg-spark:iceberg-spark3'),
+ project(":iceberg-spark:iceberg-spark3-extensions"),
+ project(':iceberg-spark:iceberg-spark3-runtime')
+]
+
+configure(sparkProjects) {
+ project.ext {
+ sparkVersion = '3.0.3'
+ }
+
+ configurations {
+ all {
+ resolutionStrategy {
+ force 'com.fasterxml.jackson.module:jackson-module-scala_2.12:2.11.4'
+ force 'com.fasterxml.jackson.module:jackson-module-paranamer:2.11.4'
+ }
+ }
+ }
+}
+
project(':iceberg-spark:iceberg-spark3') {
apply plugin: 'scala'
@@ -40,7 +61,7 @@ project(':iceberg-spark:iceberg-spark3') {
compileOnly "com.google.errorprone:error_prone_annotations"
compileOnly "org.apache.avro:avro"
- compileOnly("org.apache.spark:spark-hive_2.12:${project.ext.Spark30Version}") {
+ compileOnly("org.apache.spark:spark-hive_2.12:${sparkVersion}") {
exclude group: 'org.apache.avro', module: 'avro'
exclude group: 'org.apache.arrow'
}
@@ -109,7 +130,7 @@ project(":iceberg-spark:iceberg-spark3-extensions") {
compileOnly project(':iceberg-spark')
compileOnly project(':iceberg-spark:iceberg-spark3')
compileOnly project(':iceberg-hive-metastore')
- compileOnly("org.apache.spark:spark-hive_2.12:${project.ext.Spark30Version}") {
+ compileOnly("org.apache.spark:spark-hive_2.12:${sparkVersion}") {
exclude group: 'org.apache.avro', module: 'avro'
exclude group: 'org.apache.arrow'
}
@@ -176,7 +197,7 @@ project(':iceberg-spark:iceberg-spark3-runtime') {
exclude group: 'com.google.code.findbugs', module: 'jsr305'
}
- integrationImplementation "org.apache.spark:spark-hive_2.12:${project.ext.Spark30Version}"
+ integrationImplementation "org.apache.spark:spark-hive_2.12:${sparkVersion}"
integrationImplementation 'org.junit.vintage:junit-vintage-engine'
integrationImplementation 'org.slf4j:slf4j-simple'
integrationImplementation project(path: ':iceberg-api', configuration: 'testArtifacts')
diff --git a/spark/v3.0/build.gradle b/spark/v3.2/build.gradle
similarity index 88%
copy from spark/v3.0/build.gradle
copy to spark/v3.2/build.gradle
index ec0b7c4..19fa5a6 100644
--- a/spark/v3.0/build.gradle
+++ b/spark/v3.2/build.gradle
@@ -17,7 +17,28 @@
* under the License.
*/
-project(':iceberg-spark:iceberg-spark3') {
+def sparkProjects = [
+ project(':iceberg-spark:iceberg-spark-3.2'),
+ project(":iceberg-spark:iceberg-spark-3.2-extensions"),
+ project(':iceberg-spark:iceberg-spark-3.2-runtime')
+]
+
+configure(sparkProjects) {
+ project.ext {
+ sparkVersion = '3.2.0'
+ }
+
+ configurations {
+ all {
+ resolutionStrategy {
+ force 'com.fasterxml.jackson.module:jackson-module-scala_2.12:2.12.3'
+ force 'com.fasterxml.jackson.module:jackson-module-paranamer:2.12.3'
+ }
+ }
+ }
+}
+
+project(':iceberg-spark:iceberg-spark-3.2') {
apply plugin: 'scala'
sourceSets {
@@ -40,7 +61,7 @@ project(':iceberg-spark:iceberg-spark3') {
compileOnly "com.google.errorprone:error_prone_annotations"
compileOnly "org.apache.avro:avro"
- compileOnly("org.apache.spark:spark-hive_2.12:${project.ext.Spark30Version}") {
+ compileOnly("org.apache.spark:spark-hive_2.12:${sparkVersion}") {
exclude group: 'org.apache.avro', module: 'avro'
exclude group: 'org.apache.arrow'
}
@@ -82,7 +103,7 @@ project(':iceberg-spark:iceberg-spark3') {
}
}
-project(":iceberg-spark:iceberg-spark3-extensions") {
+project(":iceberg-spark:iceberg-spark-3.2-extensions") {
apply plugin: 'java-library'
apply plugin: 'scala'
apply plugin: 'antlr'
@@ -106,10 +127,9 @@ project(":iceberg-spark:iceberg-spark3-extensions") {
compileOnly project(':iceberg-api')
compileOnly project(':iceberg-core')
compileOnly project(':iceberg-common')
- compileOnly project(':iceberg-spark')
- compileOnly project(':iceberg-spark:iceberg-spark3')
+ compileOnly project(':iceberg-spark:iceberg-spark-3.2')
compileOnly project(':iceberg-hive-metastore')
- compileOnly("org.apache.spark:spark-hive_2.12:${project.ext.Spark30Version}") {
+ compileOnly("org.apache.spark:spark-hive_2.12:${sparkVersion}") {
exclude group: 'org.apache.avro', module: 'avro'
exclude group: 'org.apache.arrow'
}
@@ -118,8 +138,7 @@ project(":iceberg-spark:iceberg-spark3-extensions") {
testImplementation project(path: ':iceberg-api', configuration: 'testArtifacts')
testImplementation project(path: ':iceberg-hive-metastore', configuration: 'testArtifacts')
- testImplementation project(path: ':iceberg-spark', configuration: 'testArtifacts')
- testImplementation project(path: ':iceberg-spark:iceberg-spark3', configuration: 'testArtifacts')
+ testImplementation project(path: ':iceberg-spark:iceberg-spark-3.2', configuration: 'testArtifacts')
testImplementation "org.apache.avro:avro"
@@ -135,7 +154,7 @@ project(":iceberg-spark:iceberg-spark3-extensions") {
}
}
-project(':iceberg-spark:iceberg-spark3-runtime') {
+project(':iceberg-spark:iceberg-spark-3.2-runtime') {
apply plugin: 'com.github.johnrengelman.shadow'
tasks.jar.dependsOn tasks.shadowJar
@@ -169,24 +188,23 @@ project(':iceberg-spark:iceberg-spark3-runtime') {
dependencies {
api project(':iceberg-api')
- implementation project(':iceberg-spark:iceberg-spark3')
- implementation project(':iceberg-spark:iceberg-spark3-extensions')
+ implementation project(':iceberg-spark:iceberg-spark-3.2')
+ implementation project(':iceberg-spark:iceberg-spark-3.2-extensions')
implementation project(':iceberg-aws')
implementation(project(':iceberg-nessie')) {
exclude group: 'com.google.code.findbugs', module: 'jsr305'
}
- integrationImplementation "org.apache.spark:spark-hive_2.12:${project.ext.Spark30Version}"
+ integrationImplementation "org.apache.spark:spark-hive_2.12:${sparkVersion}"
integrationImplementation 'org.junit.vintage:junit-vintage-engine'
integrationImplementation 'org.slf4j:slf4j-simple'
integrationImplementation project(path: ':iceberg-api', configuration: 'testArtifacts')
integrationImplementation project(path: ':iceberg-hive-metastore', configuration: 'testArtifacts')
- integrationImplementation project(path: ':iceberg-spark', configuration: 'testArtifacts')
- integrationImplementation project(path: ':iceberg-spark:iceberg-spark3', configuration: 'testArtifacts')
- integrationImplementation project(path: ':iceberg-spark:iceberg-spark3-extensions', configuration: 'testArtifacts')
+ integrationImplementation project(path: ':iceberg-spark:iceberg-spark-3.2', configuration: 'testArtifacts')
+ integrationImplementation project(path: ':iceberg-spark:iceberg-spark-3.2-extensions', configuration: 'testArtifacts')
// Not allowed on our classpath, only the runtime jar is allowed
- integrationCompileOnly project(':iceberg-spark:iceberg-spark3-extensions')
- integrationCompileOnly project(':iceberg-spark:iceberg-spark3')
+ integrationCompileOnly project(':iceberg-spark:iceberg-spark-3.2-extensions')
+ integrationCompileOnly project(':iceberg-spark:iceberg-spark-3.2')
integrationCompileOnly project(':iceberg-api')
}
@@ -225,7 +243,7 @@ project(':iceberg-spark:iceberg-spark3-runtime') {
}
task integrationTest(type: Test) {
- description = "Test Spark3 Runtime Jar against Spark 3.0"
+ description = "Test Spark3 Runtime Jar against Spark 3.2"
group = "verification"
testClassesDirs = sourceSets.integration.output.classesDirs
classpath = sourceSets.integration.runtimeClasspath + files(shadowJar.archiveFile.get().asFile.path)
@@ -237,4 +255,3 @@ project(':iceberg-spark:iceberg-spark3-runtime') {
enabled = false
}
}
-
diff --git a/spark/v3.2/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4 b/spark/v3.2/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4
new file mode 100644
index 0000000..d0b228d
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4
@@ -0,0 +1,293 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ * This file is an adaptation of Presto's and Spark's grammar files.
+ */
+
+grammar IcebergSqlExtensions;
+
+@lexer::members {
+ /**
+ * Verify whether current token is a valid decimal token (which contains dot).
+ * Returns true if the character that follows the token is not a digit or letter or underscore.
+ *
+ * For example:
+ * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'.
+ * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'.
+ * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'.
+ * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed
+ * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+'
+ * which is not a digit or letter or underscore.
+ */
+ public boolean isValidDecimal() {
+ int nextChar = _input.LA(1);
+ if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' ||
+ nextChar == '_') {
+ return false;
+ } else {
+ return true;
+ }
+ }
+
+ /**
+ * This method will be called when we see '/*' and try to match it as a bracketed comment.
+ * If the next character is '+', it should be parsed as hint later, and we cannot match
+ * it as a bracketed comment.
+ *
+ * Returns true if the next character is '+'.
+ */
+ public boolean isHint() {
+ int nextChar = _input.LA(1);
+ if (nextChar == '+') {
+ return true;
+ } else {
+ return false;
+ }
+ }
+}
+
+singleStatement
+ : statement EOF
+ ;
+
+statement
+ : CALL multipartIdentifier '(' (callArgument (',' callArgument)*)? ')' #call
+ | ALTER TABLE multipartIdentifier ADD PARTITION FIELD transform (AS name=identifier)? #addPartitionField
+ | ALTER TABLE multipartIdentifier DROP PARTITION FIELD transform #dropPartitionField
+ | ALTER TABLE multipartIdentifier REPLACE PARTITION FIELD transform WITH transform (AS name=identifier)? #replacePartitionField
+ | ALTER TABLE multipartIdentifier WRITE writeSpec #setWriteDistributionAndOrdering
+ | ALTER TABLE multipartIdentifier SET IDENTIFIER_KW FIELDS fieldList #setIdentifierFields
+ | ALTER TABLE multipartIdentifier DROP IDENTIFIER_KW FIELDS fieldList #dropIdentifierFields
+ ;
+
+writeSpec
+ : (writeDistributionSpec | writeOrderingSpec)*
+ ;
+
+writeDistributionSpec
+ : DISTRIBUTED BY PARTITION
+ ;
+
+writeOrderingSpec
+ : LOCALLY? ORDERED BY order
+ | UNORDERED
+ ;
+
+callArgument
+ : expression #positionalArgument
+ | identifier '=>' expression #namedArgument
+ ;
+
+order
+ : fields+=orderField (',' fields+=orderField)*
+ | '(' fields+=orderField (',' fields+=orderField)* ')'
+ ;
+
+orderField
+ : transform direction=(ASC | DESC)? (NULLS nullOrder=(FIRST | LAST))?
+ ;
+
+transform
+ : multipartIdentifier #identityTransform
+ | transformName=identifier
+ '(' arguments+=transformArgument (',' arguments+=transformArgument)* ')' #applyTransform
+ ;
+
+transformArgument
+ : multipartIdentifier
+ | constant
+ ;
+
+expression
+ : constant
+ | stringMap
+ ;
+
+constant
+ : number #numericLiteral
+ | booleanValue #booleanLiteral
+ | STRING+ #stringLiteral
+ | identifier STRING #typeConstructor
+ ;
+
+stringMap
+ : MAP '(' constant (',' constant)* ')'
+ ;
+
+booleanValue
+ : TRUE | FALSE
+ ;
+
+number
+ : MINUS? EXPONENT_VALUE #exponentLiteral
+ | MINUS? DECIMAL_VALUE #decimalLiteral
+ | MINUS? INTEGER_VALUE #integerLiteral
+ | MINUS? BIGINT_LITERAL #bigIntLiteral
+ | MINUS? SMALLINT_LITERAL #smallIntLiteral
+ | MINUS? TINYINT_LITERAL #tinyIntLiteral
+ | MINUS? DOUBLE_LITERAL #doubleLiteral
+ | MINUS? FLOAT_LITERAL #floatLiteral
+ | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral
+ ;
+
+multipartIdentifier
+ : parts+=identifier ('.' parts+=identifier)*
+ ;
+
+identifier
+ : IDENTIFIER #unquotedIdentifier
+ | quotedIdentifier #quotedIdentifierAlternative
+ | nonReserved #unquotedIdentifier
+ ;
+
+quotedIdentifier
+ : BACKQUOTED_IDENTIFIER
+ ;
+
+fieldList
+ : fields+=multipartIdentifier (',' fields+=multipartIdentifier)*
+ ;
+
+nonReserved
+ : ADD | ALTER | AS | ASC | BY | CALL | DESC | DROP | FIELD | FIRST | LAST | NULLS | ORDERED | PARTITION | TABLE | WRITE
+ | DISTRIBUTED | LOCALLY | UNORDERED | REPLACE | WITH | IDENTIFIER_KW | FIELDS | SET
+ | TRUE | FALSE
+ | MAP
+ ;
+
+ADD: 'ADD';
+ALTER: 'ALTER';
+AS: 'AS';
+ASC: 'ASC';
+BY: 'BY';
+CALL: 'CALL';
+DESC: 'DESC';
+DISTRIBUTED: 'DISTRIBUTED';
+DROP: 'DROP';
+FIELD: 'FIELD';
+FIELDS: 'FIELDS';
+FIRST: 'FIRST';
+LAST: 'LAST';
+LOCALLY: 'LOCALLY';
+NULLS: 'NULLS';
+ORDERED: 'ORDERED';
+PARTITION: 'PARTITION';
+REPLACE: 'REPLACE';
+IDENTIFIER_KW: 'IDENTIFIER';
+SET: 'SET';
+TABLE: 'TABLE';
+UNORDERED: 'UNORDERED';
+WITH: 'WITH';
+WRITE: 'WRITE';
+
+TRUE: 'TRUE';
+FALSE: 'FALSE';
+
+MAP: 'MAP';
+
+PLUS: '+';
+MINUS: '-';
+
+STRING
+ : '\'' ( ~('\''|'\\') | ('\\' .) )* '\''
+ | '"' ( ~('"'|'\\') | ('\\' .) )* '"'
+ ;
+
+BIGINT_LITERAL
+ : DIGIT+ 'L'
+ ;
+
+SMALLINT_LITERAL
+ : DIGIT+ 'S'
+ ;
+
+TINYINT_LITERAL
+ : DIGIT+ 'Y'
+ ;
+
+INTEGER_VALUE
+ : DIGIT+
+ ;
+
+EXPONENT_VALUE
+ : DIGIT+ EXPONENT
+ | DECIMAL_DIGITS EXPONENT {isValidDecimal()}?
+ ;
+
+DECIMAL_VALUE
+ : DECIMAL_DIGITS {isValidDecimal()}?
+ ;
+
+FLOAT_LITERAL
+ : DIGIT+ EXPONENT? 'F'
+ | DECIMAL_DIGITS EXPONENT? 'F' {isValidDecimal()}?
+ ;
+
+DOUBLE_LITERAL
+ : DIGIT+ EXPONENT? 'D'
+ | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}?
+ ;
+
+BIGDECIMAL_LITERAL
+ : DIGIT+ EXPONENT? 'BD'
+ | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}?
+ ;
+
+IDENTIFIER
+ : (LETTER | DIGIT | '_')+
+ ;
+
+BACKQUOTED_IDENTIFIER
+ : '`' ( ~'`' | '``' )* '`'
+ ;
+
+fragment DECIMAL_DIGITS
+ : DIGIT+ '.' DIGIT*
+ | '.' DIGIT+
+ ;
+
+fragment EXPONENT
+ : 'E' [+-]? DIGIT+
+ ;
+
+fragment DIGIT
+ : [0-9]
+ ;
+
+fragment LETTER
+ : [A-Z]
+ ;
+
+SIMPLE_COMMENT
+ : '--' ('\\\n' | ~[\r\n])* '\r'? '\n'? -> channel(HIDDEN)
+ ;
+
+BRACKETED_COMMENT
+ : '/*' {!isHint()}? (BRACKETED_COMMENT|.)*? '*/' -> channel(HIDDEN)
+ ;
+
+WS
+ : [ \r\n\t]+ -> channel(HIDDEN)
+ ;
+
+// Catch-all for anything we can't recognize.
+// We use this to be able to ignore and recover all the text
+// when splitting statements with DelimiterLexer
+UNRECOGNIZED
+ : .
+ ;
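
For orientation, the grammar above covers only the Iceberg-specific statements (CALL plus the ALTER TABLE extensions); everything else falls through to Spark's own parser. A minimal sketch of the SQL it accepts, issued from a Spark 3.2 session with the extensions enabled (catalog, table, column, and procedure names below are placeholders, not part of this commit):

    // Placeholders throughout: "local" catalog, "db.events" table, column names.
    spark.sql("ALTER TABLE local.db.events ADD PARTITION FIELD bucket(16, id)")
    spark.sql("ALTER TABLE local.db.events REPLACE PARTITION FIELD id_bucket WITH days(ts)")
    spark.sql("ALTER TABLE local.db.events WRITE ORDERED BY category, id")
    spark.sql("ALTER TABLE local.db.events SET IDENTIFIER FIELDS id")
    spark.sql("CALL local.system.rewrite_manifests(table => 'db.events')")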
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala
new file mode 100644
index 0000000..4211388
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions
+
+import org.apache.spark.sql.SparkSessionExtensions
+import org.apache.spark.sql.catalyst.analysis.ProcedureArgumentCoercion
+import org.apache.spark.sql.catalyst.analysis.ResolveProcedures
+import org.apache.spark.sql.catalyst.parser.extensions.IcebergSparkSqlExtensionsParser
+import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Strategy
+
+class IcebergSparkSessionExtensions extends (SparkSessionExtensions => Unit) {
+
+ override def apply(extensions: SparkSessionExtensions): Unit = {
+ // parser extensions
+ extensions.injectParser { case (_, parser) => new IcebergSparkSqlExtensionsParser(parser) }
+
+ // analyzer extensions
+ extensions.injectResolutionRule { spark => ResolveProcedures(spark) }
+ extensions.injectResolutionRule { _ => ProcedureArgumentCoercion }
+
+ // planner extensions
+ extensions.injectPlannerStrategy { spark => ExtendedDataSourceV2Strategy(spark) }
+ }
+}
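
IcebergSparkSessionExtensions is the single entry point that wires in the parser, the analyzer rules, and the planner strategy added by this commit. A minimal sketch of enabling it on a Spark 3.2 session (the spark.sql.extensions key is standard Spark; the catalog name and warehouse path are illustrative assumptions, not part of this commit):

    import org.apache.spark.sql.SparkSession

    // Enable the Iceberg SQL extensions and register an example Hadoop catalog.
    val spark = SparkSession.builder()
      .master("local[2]")
      .config("spark.sql.extensions",
        "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
      .config("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog")
      .config("spark.sql.catalog.local.type", "hadoop")
      .config("spark.sql.catalog.local.warehouse", "/tmp/iceberg-warehouse")
      .getOrCreate()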
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ProcedureArgumentCoercion.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ProcedureArgumentCoercion.scala
new file mode 100644
index 0000000..7f0ca8f
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ProcedureArgumentCoercion.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.sql.catalyst.plans.logical.Call
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+object ProcedureArgumentCoercion extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case c @ Call(procedure, args) if c.resolved =>
+ val params = procedure.parameters
+
+ val newArgs = args.zipWithIndex.map { case (arg, index) =>
+ val param = params(index)
+ val paramType = param.dataType
+ val argType = arg.dataType
+
+ if (paramType != argType && !Cast.canUpCast(argType, paramType)) {
+ throw new AnalysisException(
+ s"Wrong arg type for ${param.name}: cannot cast $argType to $paramType")
+ }
+
+ if (paramType != argType) {
+ Cast(arg, paramType)
+ } else {
+ arg
+ }
+ }
+
+ if (newArgs != args) {
+ c.copy(args = newArgs)
+ } else {
+ c
+ }
+ }
+}
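
ProcedureArgumentCoercion widens a resolved CALL's arguments when Spark allows a safe up-cast and rejects the call otherwise. A hedged example of its effect (catalog, table, and procedure names are placeholders; it assumes a long-typed snapshot_id parameter):

    // The integer literal is widened to long by ProcedureArgumentCoercion,
    // because Cast.canUpCast(IntegerType, LongType) holds.
    spark.sql("CALL local.system.rollback_to_snapshot('db.events', 12345)")
    // An argument that cannot be up-cast (e.g. a string for snapshot_id)
    // fails analysis with "Wrong arg type ... cannot cast ...".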
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveProcedures.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveProcedures.scala
new file mode 100644
index 0000000..b50655d
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveProcedures.scala
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import java.util.Locale
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.plans.logical.Call
+import org.apache.spark.sql.catalyst.plans.logical.CallArgument
+import org.apache.spark.sql.catalyst.plans.logical.CallStatement
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.NamedArgument
+import org.apache.spark.sql.catalyst.plans.logical.PositionalArgument
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.catalog.CatalogManager
+import org.apache.spark.sql.connector.catalog.CatalogPlugin
+import org.apache.spark.sql.connector.catalog.LookupCatalog
+import org.apache.spark.sql.connector.iceberg.catalog.ProcedureCatalog
+import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter
+import scala.collection.Seq
+
+case class ResolveProcedures(spark: SparkSession) extends Rule[LogicalPlan] with LookupCatalog {
+
+ protected lazy val catalogManager: CatalogManager = spark.sessionState.catalogManager
+
+ override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case CallStatement(CatalogAndIdentifier(catalog, ident), args) =>
+ val procedure = catalog.asProcedureCatalog.loadProcedure(ident)
+
+ val params = procedure.parameters
+ val normalizedParams = normalizeParams(params)
+ validateParams(normalizedParams)
+
+ val normalizedArgs = normalizeArgs(args)
+ Call(procedure, args = buildArgExprs(normalizedParams, normalizedArgs))
+ }
+
+ private def validateParams(params: Seq[ProcedureParameter]): Unit = {
+ // should not be any duplicate param names
+ val duplicateParamNames = params.groupBy(_.name).collect {
+ case (name, matchingParams) if matchingParams.length > 1 => name
+ }
+
+ if (duplicateParamNames.nonEmpty) {
+ throw new AnalysisException(s"Duplicate parameter names: ${duplicateParamNames.mkString("[", ",", "]")}")
+ }
+
+ // optional params should be at the end
+ params.sliding(2).foreach {
+ case Seq(previousParam, currentParam) if !previousParam.required && currentParam.required =>
+ throw new AnalysisException(
+ s"Optional parameters must be after required ones but $currentParam is after $previousParam")
+ case _ =>
+ }
+ }
+
+ private def buildArgExprs(
+ params: Seq[ProcedureParameter],
+ args: Seq[CallArgument]): Seq[Expression] = {
+
+ // build a map of declared parameter names to their positions
+ val nameToPositionMap = params.map(_.name).zipWithIndex.toMap
+
+ // build a map of parameter names to args
+ val nameToArgMap = buildNameToArgMap(params, args, nameToPositionMap)
+
+ // verify all required parameters are provided
+ val missingParamNames = params.filter(_.required).collect {
+ case param if !nameToArgMap.contains(param.name) => param.name
+ }
+
+ if (missingParamNames.nonEmpty) {
+ throw new AnalysisException(s"Missing required parameters: ${missingParamNames.mkString("[", ",", "]")}")
+ }
+
+ val argExprs = new Array[Expression](params.size)
+
+ nameToArgMap.foreach { case (name, arg) =>
+ val position = nameToPositionMap(name)
+ argExprs(position) = arg.expr
+ }
+
+ // assign nulls to optional params that were not set
+ params.foreach {
+ case p if !p.required && !nameToArgMap.contains(p.name) =>
+ val position = nameToPositionMap(p.name)
+ argExprs(position) = Literal.create(null, p.dataType)
+ case _ =>
+ }
+
+ argExprs
+ }
+
+ private def buildNameToArgMap(
+ params: Seq[ProcedureParameter],
+ args: Seq[CallArgument],
+ nameToPositionMap: Map[String, Int]): Map[String, CallArgument] = {
+
+ val containsNamedArg = args.exists(_.isInstanceOf[NamedArgument])
+ val containsPositionalArg = args.exists(_.isInstanceOf[PositionalArgument])
+
+ if (containsNamedArg && containsPositionalArg) {
+ throw new AnalysisException("Named and positional arguments cannot be mixed")
+ }
+
+ if (containsNamedArg) {
+ buildNameToArgMapUsingNames(args, nameToPositionMap)
+ } else {
+ buildNameToArgMapUsingPositions(args, params)
+ }
+ }
+
+ private def buildNameToArgMapUsingNames(
+ args: Seq[CallArgument],
+ nameToPositionMap: Map[String, Int]): Map[String, CallArgument] = {
+
+ val namedArgs = args.asInstanceOf[Seq[NamedArgument]]
+
+ val validationErrors = namedArgs.groupBy(_.name).collect {
+ case (name, matchingArgs) if matchingArgs.size > 1 => s"Duplicate procedure argument: $name"
+ case (name, _) if !nameToPositionMap.contains(name) => s"Unknown argument: $name"
+ }
+
+ if (validationErrors.nonEmpty) {
+ throw new AnalysisException(s"Could not build name to arg map: ${validationErrors.mkString(", ")}")
+ }
+
+ namedArgs.map(arg => arg.name -> arg).toMap
+ }
+
+ private def buildNameToArgMapUsingPositions(
+ args: Seq[CallArgument],
+ params: Seq[ProcedureParameter]): Map[String, CallArgument] = {
+
+ if (args.size > params.size) {
+ throw new AnalysisException("Too many arguments for procedure")
+ }
+
+ args.zipWithIndex.map { case (arg, position) =>
+ val param = params(position)
+ param.name -> arg
+ }.toMap
+ }
+
+ private def normalizeParams(params: Seq[ProcedureParameter]): Seq[ProcedureParameter] = {
+ params.map {
+ case param if param.required =>
+ val normalizedName = param.name.toLowerCase(Locale.ROOT)
+ ProcedureParameter.required(normalizedName, param.dataType)
+ case param =>
+ val normalizedName = param.name.toLowerCase(Locale.ROOT)
+ ProcedureParameter.optional(normalizedName, param.dataType)
+ }
+ }
+
+ private def normalizeArgs(args: Seq[CallArgument]): Seq[CallArgument] = {
+ args.map {
+ case a @ NamedArgument(name, _) => a.copy(name = name.toLowerCase(Locale.ROOT))
+ case other => other
+ }
+ }
+
+ implicit class CatalogHelper(plugin: CatalogPlugin) {
+ def asProcedureCatalog: ProcedureCatalog = plugin match {
+ case procedureCatalog: ProcedureCatalog =>
+ procedureCatalog
+ case _ =>
+ throw new AnalysisException(s"Cannot use catalog ${plugin.name}: not a ProcedureCatalog")
+ }
+ }
+}
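
ResolveProcedures looks the procedure up in the target catalog, binds arguments to declared parameters (named arguments are matched case-insensitively and may not be mixed with positional ones), and fills unset optional parameters with null literals. A hedged sketch of the two calling styles it resolves (catalog, table, and procedure names are placeholders, not part of this commit):

    // Positional arguments bind in declaration order.
    spark.sql("CALL local.system.expire_snapshots('db.events')")
    // Named arguments may come in any order; mixing the two styles
    // raises an AnalysisException in this rule.
    spark.sql("CALL local.system.expire_snapshots(retain_last => 5, table => 'db.events')")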
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala
new file mode 100644
index 0000000..b950bd7
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala
@@ -0,0 +1,295 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.catalyst.parser.extensions
+
+import java.util.Locale
+import org.antlr.v4.runtime._
+import org.antlr.v4.runtime.atn.PredictionMode
+import org.antlr.v4.runtime.misc.Interval
+import org.antlr.v4.runtime.misc.ParseCancellationException
+import org.antlr.v4.runtime.tree.TerminalNodeImpl
+import org.apache.iceberg.common.DynConstructors
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser.NonReservedContext
+import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser.QuotedIdentifierContext
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.VariableSubstitution
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.StructType
+
+class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) extends ParserInterface {
+
+ import IcebergSparkSqlExtensionsParser._
+
+ private lazy val substitutor = substitutorCtor.newInstance(SQLConf.get)
+ private lazy val astBuilder = new IcebergSqlExtensionsAstBuilder(delegate)
+
+ /**
+ * Parse a string to a DataType.
+ */
+ override def parseDataType(sqlText: String): DataType = {
+ delegate.parseDataType(sqlText)
+ }
+
+ /**
+ * Parse a string to a raw DataType without CHAR/VARCHAR replacement.
+ */
+ def parseRawDataType(sqlText: String): DataType = throw new UnsupportedOperationException()
+
+ /**
+ * Parse a string to an Expression.
+ */
+ override def parseExpression(sqlText: String): Expression = {
+ delegate.parseExpression(sqlText)
+ }
+
+ /**
+ * Parse a string to a TableIdentifier.
+ */
+ override def parseTableIdentifier(sqlText: String): TableIdentifier = {
+ delegate.parseTableIdentifier(sqlText)
+ }
+
+ /**
+ * Parse a string to a FunctionIdentifier.
+ */
+ override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = {
+ delegate.parseFunctionIdentifier(sqlText)
+ }
+
+ /**
+ * Parse a string to a multi-part identifier.
+ */
+ override def parseMultipartIdentifier(sqlText: String): Seq[String] = {
+ delegate.parseMultipartIdentifier(sqlText)
+ }
+
+ /**
+ * Creates StructType for a given SQL string, which is a comma separated list of field
+ * definitions which will preserve the correct Hive metadata.
+ */
+ override def parseTableSchema(sqlText: String): StructType = {
+ delegate.parseTableSchema(sqlText)
+ }
+
+ /**
+ * Parse a string to a LogicalPlan.
+ */
+ override def parsePlan(sqlText: String): LogicalPlan = {
+ val sqlTextAfterSubstitution = substitutor.substitute(sqlText)
+ if (isIcebergCommand(sqlTextAfterSubstitution)) {
+ parse(sqlTextAfterSubstitution) { parser => astBuilder.visit(parser.singleStatement()) }.asInstanceOf[LogicalPlan]
+ } else {
+ delegate.parsePlan(sqlText)
+ }
+ }
+
+ private def isIcebergCommand(sqlText: String): Boolean = {
+ val normalized = sqlText.toLowerCase(Locale.ROOT).trim().replaceAll("\\s+", " ")
+ normalized.startsWith("call") || (
+ normalized.startsWith("alter table") && (
+ normalized.contains("add partition field") ||
+ normalized.contains("drop partition field") ||
+ normalized.contains("replace partition field") ||
+ normalized.contains("write ordered by") ||
+ normalized.contains("write locally ordered by") ||
+ normalized.contains("write distributed by") ||
+ normalized.contains("write unordered") ||
+ normalized.contains("set identifier fields") ||
+ normalized.contains("drop identifier fields")))
+ }
+
+ protected def parse[T](command: String)(toResult: IcebergSqlExtensionsParser => T): T = {
+ val lexer = new IcebergSqlExtensionsLexer(new UpperCaseCharStream(CharStreams.fromString(command)))
+ lexer.removeErrorListeners()
+ lexer.addErrorListener(IcebergParseErrorListener)
+
+ val tokenStream = new CommonTokenStream(lexer)
+ val parser = new IcebergSqlExtensionsParser(tokenStream)
+ parser.addParseListener(IcebergSqlExtensionsPostProcessor)
+ parser.removeErrorListeners()
+ parser.addErrorListener(IcebergParseErrorListener)
+
+ try {
+ try {
+ // first, try parsing with potentially faster SLL mode
+ parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
+ toResult(parser)
+ }
+ catch {
+ case _: ParseCancellationException =>
+ // if we fail, parse with LL mode
+ tokenStream.seek(0) // rewind input stream
+ parser.reset()
+
+ // try again
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL)
+ toResult(parser)
+ }
+ }
+ catch {
+ case e: IcebergParseException if e.command.isDefined =>
+ throw e
+ case e: IcebergParseException =>
+ throw e.withCommand(command)
+ case e: AnalysisException =>
+ val position = Origin(e.line, e.startPosition)
+ throw new IcebergParseException(Option(command), e.message, position, position)
+ }
+ }
+}
+
+object IcebergSparkSqlExtensionsParser {
+ private val substitutorCtor: DynConstructors.Ctor[VariableSubstitution] =
+ DynConstructors.builder()
+ .impl(classOf[VariableSubstitution])
+ .impl(classOf[VariableSubstitution], classOf[SQLConf])
+ .build()
+}
+
+/* Copied from Apache Spark's Parser to avoid dependency on Spark Internals */
+class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream {
+ override def consume(): Unit = wrapped.consume
+ override def getSourceName(): String = wrapped.getSourceName
+ override def index(): Int = wrapped.index
+ override def mark(): Int = wrapped.mark
+ override def release(marker: Int): Unit = wrapped.release(marker)
+ override def seek(where: Int): Unit = wrapped.seek(where)
+ override def size(): Int = wrapped.size
+
+ override def getText(interval: Interval): String = wrapped.getText(interval)
+
+ // scalastyle:off
+ override def LA(i: Int): Int = {
+ val la = wrapped.LA(i)
+ if (la == 0 || la == IntStream.EOF) la
+ else Character.toUpperCase(la)
+ }
+ // scalastyle:on
+}
+
+/**
+ * The post-processor validates and cleans up the parse tree during the parse process.
+ */
+case object IcebergSqlExtensionsPostProcessor extends IcebergSqlExtensionsBaseListener {
+
+ /** Remove the back ticks from an Identifier. */
+ override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = {
+ replaceTokenByIdentifier(ctx, 1) { token =>
+ // Remove the double back ticks in the string.
+ token.setText(token.getText.replace("``", "`"))
+ token
+ }
+ }
+
+ /** Treat non-reserved keywords as Identifiers. */
+ override def exitNonReserved(ctx: NonReservedContext): Unit = {
+ replaceTokenByIdentifier(ctx, 0)(identity)
+ }
+
+ private def replaceTokenByIdentifier(
+ ctx: ParserRuleContext,
+ stripMargins: Int)(
+ f: CommonToken => CommonToken = identity): Unit = {
+ val parent = ctx.getParent
+ parent.removeLastChild()
+ val token = ctx.getChild(0).getPayload.asInstanceOf[Token]
+ val newToken = new CommonToken(
+ new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream),
+ IcebergSqlExtensionsParser.IDENTIFIER,
+ token.getChannel,
+ token.getStartIndex + stripMargins,
+ token.getStopIndex - stripMargins)
+ parent.addChild(new TerminalNodeImpl(f(newToken)))
+ }
+}
+
+/* Partially copied from Apache Spark's Parser to avoid dependency on Spark Internals */
+case object IcebergParseErrorListener extends BaseErrorListener {
+ override def syntaxError(
+ recognizer: Recognizer[_, _],
+ offendingSymbol: scala.Any,
+ line: Int,
+ charPositionInLine: Int,
+ msg: String,
+ e: RecognitionException): Unit = {
+ val (start, stop) = offendingSymbol match {
+ case token: CommonToken =>
+ val start = Origin(Some(line), Some(token.getCharPositionInLine))
+ val length = token.getStopIndex - token.getStartIndex + 1
+ val stop = Origin(Some(line), Some(token.getCharPositionInLine + length))
+ (start, stop)
+ case _ =>
+ val start = Origin(Some(line), Some(charPositionInLine))
+ (start, start)
+ }
+ throw new IcebergParseException(None, msg, start, stop)
+ }
+}
+
+/**
+ * Copied from Apache Spark
+ * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It
+ * contains fields and an extended error message that make reporting and diagnosing errors easier.
+ */
+class IcebergParseException(
+ val command: Option[String],
+ message: String,
+ val start: Origin,
+ val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) {
+
+ def this(message: String, ctx: ParserRuleContext) = {
+ this(Option(IcebergParserUtils.command(ctx)),
+ message,
+ IcebergParserUtils.position(ctx.getStart),
+ IcebergParserUtils.position(ctx.getStop))
+ }
+
+ override def getMessage: String = {
+ val builder = new StringBuilder
+ builder ++= "\n" ++= message
+ start match {
+ case Origin(Some(l), Some(p)) =>
+ builder ++= s"(line $l, pos $p)\n"
+ command.foreach { cmd =>
+ val (above, below) = cmd.split("\n").splitAt(l)
+ builder ++= "\n== SQL ==\n"
+ above.foreach(builder ++= _ += '\n')
+ builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n"
+ below.foreach(builder ++= _ += '\n')
+ }
+ case _ =>
+ command.foreach { cmd =>
+ builder ++= "\n== SQL ==\n" ++= cmd
+ }
+ }
+ builder.toString
+ }
+
+ def withCommand(cmd: String): IcebergParseException = {
+ new IcebergParseException(Option(cmd), message, start, stop)
+ }
+}
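For orientation, parsePlan above only hands a statement to the Iceberg grammar when isIcebergCommand matches it; everything else is delegated to the wrapped Spark parser. A minimal sketch of that split, assuming an active SparkSession named spark with the extensions enabled (the catalog, table, and snapshot id below are made up):

    // handled by the extension parser: CALL statements and the extended ALTER TABLE clauses
    spark.sql("CALL my_catalog.system.rollback_to_snapshot('db.tbl', 123)")
    spark.sql("ALTER TABLE db.tbl ADD PARTITION FIELD days(ts)")
    spark.sql("ALTER TABLE db.tbl WRITE ORDERED BY category, id")

    // not matched by isIcebergCommand, so parsePlan falls through to the delegate parser
    spark.sql("SELECT * FROM db.tbl")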
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala
new file mode 100644
index 0000000..678da9b
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala
@@ -0,0 +1,286 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.catalyst.parser.extensions
+
+import org.antlr.v4.runtime._
+import org.antlr.v4.runtime.misc.Interval
+import org.antlr.v4.runtime.tree.ParseTree
+import org.antlr.v4.runtime.tree.TerminalNode
+import org.apache.iceberg.DistributionMode
+import org.apache.iceberg.NullOrder
+import org.apache.iceberg.SortDirection
+import org.apache.iceberg.expressions.Term
+import org.apache.iceberg.spark.Spark3Util
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.parser.extensions.IcebergParserUtils.withOrigin
+import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser._
+import org.apache.spark.sql.catalyst.plans.logical.AddPartitionField
+import org.apache.spark.sql.catalyst.plans.logical.CallArgument
+import org.apache.spark.sql.catalyst.plans.logical.CallStatement
+import org.apache.spark.sql.catalyst.plans.logical.DropIdentifierFields
+import org.apache.spark.sql.catalyst.plans.logical.DropPartitionField
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.NamedArgument
+import org.apache.spark.sql.catalyst.plans.logical.PositionalArgument
+import org.apache.spark.sql.catalyst.plans.logical.ReplacePartitionField
+import org.apache.spark.sql.catalyst.plans.logical.SetIdentifierFields
+import org.apache.spark.sql.catalyst.plans.logical.SetWriteDistributionAndOrdering
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin
+import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.connector.expressions
+import org.apache.spark.sql.connector.expressions.ApplyTransform
+import org.apache.spark.sql.connector.expressions.FieldReference
+import org.apache.spark.sql.connector.expressions.IdentityTransform
+import org.apache.spark.sql.connector.expressions.LiteralValue
+import org.apache.spark.sql.connector.expressions.Transform
+import scala.collection.JavaConverters._
+
+class IcebergSqlExtensionsAstBuilder(delegate: ParserInterface) extends IcebergSqlExtensionsBaseVisitor[AnyRef] {
+
+ /**
+ * Create a [[CallStatement]] for a stored procedure call.
+ */
+ override def visitCall(ctx: CallContext): CallStatement = withOrigin(ctx) {
+ val name = ctx.multipartIdentifier.parts.asScala.map(_.getText)
+ val args = ctx.callArgument.asScala.map(typedVisit[CallArgument])
+ CallStatement(name, args)
+ }
+
+ /**
+ * Create an ADD PARTITION FIELD logical command.
+ */
+ override def visitAddPartitionField(ctx: AddPartitionFieldContext): AddPartitionField = withOrigin(ctx) {
+ AddPartitionField(
+ typedVisit[Seq[String]](ctx.multipartIdentifier),
+ typedVisit[Transform](ctx.transform),
+ Option(ctx.name).map(_.getText))
+ }
+
+ /**
+ * Create a DROP PARTITION FIELD logical command.
+ */
+ override def visitDropPartitionField(ctx: DropPartitionFieldContext): DropPartitionField = withOrigin(ctx) {
+ DropPartitionField(
+ typedVisit[Seq[String]](ctx.multipartIdentifier),
+ typedVisit[Transform](ctx.transform))
+ }
+
+
+ /**
+ * Create a REPLACE PARTITION FIELD logical command.
+ */
+ override def visitReplacePartitionField(ctx: ReplacePartitionFieldContext): ReplacePartitionField = withOrigin(ctx) {
+ ReplacePartitionField(
+ typedVisit[Seq[String]](ctx.multipartIdentifier),
+ typedVisit[Transform](ctx.transform(0)),
+ typedVisit[Transform](ctx.transform(1)),
+ Option(ctx.name).map(_.getText))
+ }
+
+ /**
+ * Create a SET IDENTIFIER FIELDS logical command.
+ */
+ override def visitSetIdentifierFields(ctx: SetIdentifierFieldsContext): SetIdentifierFields = withOrigin(ctx) {
+ SetIdentifierFields(
+ typedVisit[Seq[String]](ctx.multipartIdentifier),
+ ctx.fieldList.fields.asScala.map(_.getText))
+ }
+
+ /**
+ * Create a DROP IDENTIFIER FIELDS logical command.
+ */
+ override def visitDropIdentifierFields(ctx: DropIdentifierFieldsContext): DropIdentifierFields = withOrigin(ctx) {
+ DropIdentifierFields(
+ typedVisit[Seq[String]](ctx.multipartIdentifier),
+ ctx.fieldList.fields.asScala.map(_.getText))
+ }
+
+ /**
+ * Create a [[SetWriteDistributionAndOrdering]] for changing the write distribution and ordering.
+ */
+ override def visitSetWriteDistributionAndOrdering(
+ ctx: SetWriteDistributionAndOrderingContext): SetWriteDistributionAndOrdering = {
+
+ val tableName = typedVisit[Seq[String]](ctx.multipartIdentifier)
+
+ val (distributionSpec, orderingSpec) = toDistributionAndOrderingSpec(ctx.writeSpec)
+
+ if (distributionSpec == null && orderingSpec == null) {
+ throw new AnalysisException(
+ "ALTER TABLE has no changes: missing both distribution and ordering clauses")
+ }
+
+ val distributionMode = if (distributionSpec != null) {
+ DistributionMode.HASH
+ } else if (orderingSpec.UNORDERED != null || orderingSpec.LOCALLY != null) {
+ DistributionMode.NONE
+ } else {
+ DistributionMode.RANGE
+ }
+
+ val ordering = if (orderingSpec != null && orderingSpec.order != null) {
+ orderingSpec.order.fields.asScala.map(typedVisit[(Term, SortDirection, NullOrder)])
+ } else {
+ Seq.empty
+ }
+
+ SetWriteDistributionAndOrdering(tableName, distributionMode, ordering)
+ }
+
+ private def toDistributionAndOrderingSpec(
+ writeSpec: WriteSpecContext): (WriteDistributionSpecContext, WriteOrderingSpecContext) = {
+
+ if (writeSpec.writeDistributionSpec.size > 1) {
+ throw new AnalysisException("ALTER TABLE contains multiple distribution clauses")
+ }
+
+ if (writeSpec.writeOrderingSpec.size > 1) {
+ throw new AnalysisException("ALTER TABLE contains multiple ordering clauses")
+ }
+
+ val distributionSpec = writeSpec.writeDistributionSpec.asScala.headOption.orNull
+ val orderingSpec = writeSpec.writeOrderingSpec.asScala.headOption.orNull
+
+ (distributionSpec, orderingSpec)
+ }
+
+ /**
+ * Create an order field.
+ */
+ override def visitOrderField(ctx: OrderFieldContext): (Term, SortDirection, NullOrder) = {
+ val term = Spark3Util.toIcebergTerm(typedVisit[Transform](ctx.transform))
+ val direction = Option(ctx.ASC).map(_ => SortDirection.ASC)
+ .orElse(Option(ctx.DESC).map(_ => SortDirection.DESC))
+ .getOrElse(SortDirection.ASC)
+ val nullOrder = Option(ctx.FIRST).map(_ => NullOrder.NULLS_FIRST)
+ .orElse(Option(ctx.LAST).map(_ => NullOrder.NULLS_LAST))
+ .getOrElse(if (direction == SortDirection.ASC) NullOrder.NULLS_FIRST else NullOrder.NULLS_LAST)
+ (term, direction, nullOrder)
+ }
+
+ /**
+ * Create an IdentityTransform for a column reference.
+ */
+ override def visitIdentityTransform(ctx: IdentityTransformContext): Transform = withOrigin(ctx) {
+ IdentityTransform(FieldReference(typedVisit[Seq[String]](ctx.multipartIdentifier())))
+ }
+
+ /**
+ * Create a named Transform from argument expressions.
+ */
+ override def visitApplyTransform(ctx: ApplyTransformContext): Transform = withOrigin(ctx) {
+ val args = ctx.arguments.asScala.map(typedVisit[expressions.Expression])
+ ApplyTransform(ctx.transformName.getText, args)
+ }
+
+ /**
+ * Create a transform argument from a column reference or a constant.
+ */
+ override def visitTransformArgument(ctx: TransformArgumentContext): expressions.Expression = withOrigin(ctx) {
+ val reference = Option(ctx.multipartIdentifier())
+ .map(typedVisit[Seq[String]])
+ .map(FieldReference(_))
+ val literal = Option(ctx.constant)
+ .map(visitConstant)
+ .map(lit => LiteralValue(lit.value, lit.dataType))
+ reference.orElse(literal)
+ .getOrElse(throw new IcebergParseException(s"Invalid transform argument", ctx))
+ }
+
+ /**
+ * Return a multi-part identifier as Seq[String].
+ */
+ override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] = withOrigin(ctx) {
+ ctx.parts.asScala.map(_.getText)
+ }
+
+ /**
+ * Create a positional argument in a stored procedure call.
+ */
+ override def visitPositionalArgument(ctx: PositionalArgumentContext): CallArgument = withOrigin(ctx) {
+ val expr = typedVisit[Expression](ctx.expression)
+ PositionalArgument(expr)
+ }
+
+ /**
+ * Create a named argument in a stored procedure call.
+ */
+ override def visitNamedArgument(ctx: NamedArgumentContext): CallArgument = withOrigin(ctx) {
+ val name = ctx.identifier.getText
+ val expr = typedVisit[Expression](ctx.expression)
+ NamedArgument(name, expr)
+ }
+
+ override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
+ visit(ctx.statement).asInstanceOf[LogicalPlan]
+ }
+
+ def visitConstant(ctx: ConstantContext): Literal = {
+ delegate.parseExpression(ctx.getText).asInstanceOf[Literal]
+ }
+
+ override def visitExpression(ctx: ExpressionContext): Expression = {
+ // reconstruct the SQL string and parse it using the main Spark parser;
+ // this avoids reimplementing the logic that builds Spark expressions, but the text still has to be parsed
+ // we cannot call ctx.getText directly since it does not render spaces correctly,
+ // which is why we recurse down the tree in reconstructSqlString
+ val sqlString = reconstructSqlString(ctx)
+ delegate.parseExpression(sqlString)
+ }
+
+ private def reconstructSqlString(ctx: ParserRuleContext): String = {
+ ctx.children.asScala.map {
+ case c: ParserRuleContext => reconstructSqlString(c)
+ case t: TerminalNode => t.getText
+ }.mkString(" ")
+ }
+
+ private def typedVisit[T](ctx: ParseTree): T = {
+ ctx.accept(this).asInstanceOf[T]
+ }
+}
+
+/* Partially copied from Apache Spark's Parser to avoid dependency on Spark Internals */
+object IcebergParserUtils {
+
+ private[sql] def withOrigin[T](ctx: ParserRuleContext)(f: => T): T = {
+ val current = CurrentOrigin.get
+ CurrentOrigin.set(position(ctx.getStart))
+ try {
+ f
+ } finally {
+ CurrentOrigin.set(current)
+ }
+ }
+
+ private[sql] def position(token: Token): Origin = {
+ val opt = Option(token)
+ Origin(opt.map(_.getLine), opt.map(_.getCharPositionInLine))
+ }
+
+ /** Get the command which created the token. */
+ private[sql] def command(ctx: ParserRuleContext): String = {
+ val stream = ctx.getStart.getInputStream
+ stream.getText(Interval.of(0, stream.size() - 1))
+ }
+}
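visitSetWriteDistributionAndOrdering and visitOrderField above encode the defaults: a distribution clause yields hash distribution, WRITE LOCALLY ORDERED BY or WRITE UNORDERED yields none, a plain WRITE ORDERED BY yields range distribution, and a sort field without an explicit direction defaults to ASC NULLS FIRST (DESC defaults to NULLS LAST). A hedged illustration with a made-up table name:

    // range distribution plus a table-wide sort order; "id" defaults to ASC NULLS FIRST
    spark.sql("ALTER TABLE db.tbl WRITE ORDERED BY id, category DESC NULLS LAST")

    // no redistribution, only a per-task sort order
    spark.sql("ALTER TABLE db.tbl WRITE LOCALLY ORDERED BY id")

    // hash distribution by the partition spec, no sort order
    spark.sql("ALTER TABLE db.tbl WRITE DISTRIBUTED BY PARTITION")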
diff --git a/spark/build.gradle b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AddPartitionField.scala
similarity index 58%
copy from spark/build.gradle
copy to spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AddPartitionField.scala
index 30ea7fe..e8b1b29 100644
--- a/spark/build.gradle
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AddPartitionField.scala
@@ -17,13 +17,17 @@
* under the License.
*/
-// add enabled Spark version modules to the build
-def sparkVersions = (System.getProperty("sparkVersions") != null ? System.getProperty("sparkVersions") : System.getProperty("defaultSparkVersions")).split(",")
+package org.apache.spark.sql.catalyst.plans.logical
-if (jdkVersion == '8' && sparkVersions.contains("2.4")) {
- apply from: file("$projectDir/v2.4/build.gradle")
-}
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.connector.expressions.Transform
+
+case class AddPartitionField(table: Seq[String], transform: Transform, name: Option[String]) extends LeafCommand {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+ override lazy val output: Seq[Attribute] = Nil
-if (sparkVersions.contains("3.0")) {
- apply from: file("$projectDir/v3.0/build.gradle")
+ override def simpleString(maxFields: Int): String = {
+ s"AddPartitionField ${table.quoted} ${name.map(n => s"$n=").getOrElse("")}${transform.describe}"
+ }
}
diff --git a/spark/build.gradle b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Call.scala
similarity index 56%
copy from spark/build.gradle
copy to spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Call.scala
index 30ea7fe..551996e 100644
--- a/spark/build.gradle
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Call.scala
@@ -17,13 +17,18 @@
* under the License.
*/
-// add enabled Spark version modules to the build
-def sparkVersions = (System.getProperty("sparkVersions") != null ? System.getProperty("sparkVersions") : System.getProperty("defaultSparkVersions")).split(",")
+package org.apache.spark.sql.catalyst.plans.logical
-if (jdkVersion == '8' && sparkVersions.contains("2.4")) {
- apply from: file("$projectDir/v2.4/build.gradle")
-}
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.connector.iceberg.catalog.Procedure
+import scala.collection.Seq
+
+case class Call(procedure: Procedure, args: Seq[Expression]) extends LeafCommand {
+ override lazy val output: Seq[Attribute] = procedure.outputType.toAttributes
-if (sparkVersions.contains("3.0")) {
- apply from: file("$projectDir/v3.0/build.gradle")
+ override def simpleString(maxFields: Int): String = {
+ s"Call${truncatedString(output, "[", ", ", "]", maxFields)} ${procedure.description}"
+ }
}
diff --git a/spark/build.gradle b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropIdentifierFields.scala
similarity index 63%
copy from spark/build.gradle
copy to spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropIdentifierFields.scala
index 30ea7fe..29dd686 100644
--- a/spark/build.gradle
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropIdentifierFields.scala
@@ -17,13 +17,18 @@
* under the License.
*/
-// add enabled Spark version modules to the build
-def sparkVersions = (System.getProperty("sparkVersions") != null ? System.getProperty("sparkVersions") : System.getProperty("defaultSparkVersions")).split(",")
+package org.apache.spark.sql.catalyst.plans.logical
-if (jdkVersion == '8' && sparkVersions.contains("2.4")) {
- apply from: file("$projectDir/v2.4/build.gradle")
-}
+import org.apache.spark.sql.catalyst.expressions.Attribute
+
+case class DropIdentifierFields(
+ table: Seq[String],
+ fields: Seq[String]) extends LeafCommand {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+ override lazy val output: Seq[Attribute] = Nil
-if (sparkVersions.contains("3.0")) {
- apply from: file("$projectDir/v3.0/build.gradle")
+ override def simpleString(maxFields: Int): String = {
+ s"DropIdentifierFields ${table.quoted} (${fields.quoted})"
+ }
}
diff --git a/spark/build.gradle b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropPartitionField.scala
similarity index 61%
copy from spark/build.gradle
copy to spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropPartitionField.scala
index 30ea7fe..fb14513 100644
--- a/spark/build.gradle
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DropPartitionField.scala
@@ -17,13 +17,17 @@
* under the License.
*/
-// add enabled Spark version modules to the build
-def sparkVersions = (System.getProperty("sparkVersions") != null ? System.getProperty("sparkVersions") : System.getProperty("defaultSparkVersions")).split(",")
+package org.apache.spark.sql.catalyst.plans.logical
-if (jdkVersion == '8' && sparkVersions.contains("2.4")) {
- apply from: file("$projectDir/v2.4/build.gradle")
-}
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.connector.expressions.Transform
+
+case class DropPartitionField(table: Seq[String], transform: Transform) extends LeafCommand {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+ override lazy val output: Seq[Attribute] = Nil
-if (sparkVersions.contains("3.0")) {
- apply from: file("$projectDir/v3.0/build.gradle")
+ override def simpleString(maxFields: Int): String = {
+ s"DropPartitionField ${table.quoted} ${transform.describe}"
+ }
}
diff --git a/spark/build.gradle b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplacePartitionField.scala
similarity index 54%
copy from spark/build.gradle
copy to spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplacePartitionField.scala
index 30ea7fe..8c660c6 100644
--- a/spark/build.gradle
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplacePartitionField.scala
@@ -17,13 +17,22 @@
* under the License.
*/
-// add enabled Spark version modules to the build
-def sparkVersions = (System.getProperty("sparkVersions") != null ? System.getProperty("sparkVersions") : System.getProperty("defaultSparkVersions")).split(",")
+package org.apache.spark.sql.catalyst.plans.logical
-if (jdkVersion == '8' && sparkVersions.contains("2.4")) {
- apply from: file("$projectDir/v2.4/build.gradle")
-}
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.connector.expressions.Transform
+
+case class ReplacePartitionField(
+ table: Seq[String],
+ transformFrom: Transform,
+ transformTo: Transform,
+ name: Option[String]) extends LeafCommand {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+ override lazy val output: Seq[Attribute] = Nil
-if (sparkVersions.contains("3.0")) {
- apply from: file("$projectDir/v3.0/build.gradle")
+ override def simpleString(maxFields: Int): String = {
+ s"ReplacePartitionField ${table.quoted} ${transformFrom.describe} " +
+ s"with ${name.map(n => s"$n=").getOrElse("")}${transformTo.describe}"
+ }
}
diff --git a/spark/build.gradle b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetIdentifierFields.scala
similarity index 61%
copy from spark/build.gradle
copy to spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetIdentifierFields.scala
index 30ea7fe..a5fa28a 100644
--- a/spark/build.gradle
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetIdentifierFields.scala
@@ -17,13 +17,19 @@
* under the License.
*/
-// add enabled Spark version modules to the build
-def sparkVersions = (System.getProperty("sparkVersions") != null ? System.getProperty("sparkVersions") : System.getProperty("defaultSparkVersions")).split(",")
+package org.apache.spark.sql.catalyst.plans.logical
-if (jdkVersion == '8' && sparkVersions.contains("2.4")) {
- apply from: file("$projectDir/v2.4/build.gradle")
-}
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.connector.expressions.Transform
+
+case class SetIdentifierFields(
+ table: Seq[String],
+ fields: Seq[String]) extends LeafCommand {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+ override lazy val output: Seq[Attribute] = Nil
-if (sparkVersions.contains("3.0")) {
- apply from: file("$projectDir/v3.0/build.gradle")
+ override def simpleString(maxFields: Int): String = {
+ s"SetIdentifierFields ${table.quoted} (${fields.quoted})"
+ }
}
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetWriteDistributionAndOrdering.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetWriteDistributionAndOrdering.scala
new file mode 100644
index 0000000..0a0234c
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SetWriteDistributionAndOrdering.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.iceberg.DistributionMode
+import org.apache.iceberg.NullOrder
+import org.apache.iceberg.SortDirection
+import org.apache.iceberg.expressions.Term
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits
+
+case class SetWriteDistributionAndOrdering(
+ table: Seq[String],
+ distributionMode: DistributionMode,
+ sortOrder: Seq[(Term, SortDirection, NullOrder)]) extends LeafCommand {
+
+ import CatalogV2Implicits._
+
+ override lazy val output: Seq[Attribute] = Nil
+
+ override def simpleString(maxFields: Int): String = {
+ val order = sortOrder.map {
+ case (term, direction, nullOrder) => s"$term $direction $nullOrder"
+ }.mkString(", ")
+ s"SetWriteDistributionAndOrdering ${table.quoted} $distributionMode $order"
+ }
+}
diff --git a/spark/build.gradle b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
similarity index 56%
copy from spark/build.gradle
copy to spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
index 30ea7fe..be15f32 100644
--- a/spark/build.gradle
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
@@ -17,13 +17,28 @@
* under the License.
*/
-// add enabled Spark version modules to the build
-def sparkVersions = (System.getProperty("sparkVersions") != null ? System.getProperty("sparkVersions") : System.getProperty("defaultSparkVersions")).split(",")
+package org.apache.spark.sql.catalyst.plans.logical
-if (jdkVersion == '8' && sparkVersions.contains("2.4")) {
- apply from: file("$projectDir/v2.4/build.gradle")
-}
+import org.apache.spark.sql.catalyst.expressions.Expression
+
+/**
+ * A CALL statement, as parsed from SQL.
+ */
+case class CallStatement(name: Seq[String], args: Seq[CallArgument]) extends LeafParsedStatement
-if (sparkVersions.contains("3.0")) {
- apply from: file("$projectDir/v3.0/build.gradle")
+/**
+ * An argument in a CALL statement.
+ */
+sealed trait CallArgument {
+ def expr: Expression
}
+
+/**
+ * An argument in a CALL statement identified by name.
+ */
+case class NamedArgument(name: String, expr: Expression) extends CallArgument
+
+/**
+ * An argument in a CALL statement identified by position.
+ */
+case class PositionalArgument(expr: Expression) extends CallArgument
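CallStatement is the unresolved form produced by the parser; ResolveProcedures later turns it into a Call with a resolved Procedure. As a rough illustration, a hypothetical statement like CALL cat.system.some_proc(table => 'db.tbl', snapshot_id => 42) would parse to something along these lines (some_proc is not a real procedure):

    import org.apache.spark.sql.catalyst.expressions.Literal

    // unresolved CALL statement as built by the AST builder
    val stmt = CallStatement(
      name = Seq("cat", "system", "some_proc"),
      args = Seq(
        NamedArgument("table", Literal("db.tbl")),
        NamedArgument("snapshot_id", Literal(42))))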
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddPartitionFieldExec.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddPartitionFieldExec.scala
new file mode 100644
index 0000000..55f327f
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddPartitionFieldExec.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.iceberg.spark.Spark3Util
+import org.apache.iceberg.spark.source.SparkTable
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.connector.catalog.TableCatalog
+import org.apache.spark.sql.connector.expressions.Transform
+
+case class AddPartitionFieldExec(
+ catalog: TableCatalog,
+ ident: Identifier,
+ transform: Transform,
+ name: Option[String]) extends LeafV2CommandExec {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+ override lazy val output: Seq[Attribute] = Nil
+
+ override protected def run(): Seq[InternalRow] = {
+ catalog.loadTable(ident) match {
+ case iceberg: SparkTable =>
+ iceberg.table.updateSpec()
+ .addField(name.orNull, Spark3Util.toIcebergTerm(transform))
+ .commit()
+
+ case table =>
+ throw new UnsupportedOperationException(s"Cannot add partition field to non-Iceberg table: $table")
+ }
+
+ Nil
+ }
+
+ override def simpleString(maxFields: Int): String = {
+ s"AddPartitionField ${catalog.name}.${ident.quoted} ${name.map(n => s"$n=").getOrElse("")}${transform.describe}"
+ }
+}
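AddPartitionFieldExec loads the Iceberg table behind the identifier and commits a partition spec update; non-Iceberg tables are rejected. A hedged SQL-level sketch (table and column names are made up):

    // adds a partition field derived from a timestamp column
    spark.sql("ALTER TABLE db.tbl ADD PARTITION FIELD days(ts)")

    // adds a named bucket field; the optional AS clause becomes the `name` argument above
    spark.sql("ALTER TABLE db.tbl ADD PARTITION FIELD bucket(16, id) AS id_bucket")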
diff --git a/spark/build.gradle b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CallExec.scala
similarity index 55%
copy from spark/build.gradle
copy to spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CallExec.scala
index 30ea7fe..8a88f35 100644
--- a/spark/build.gradle
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CallExec.scala
@@ -17,13 +17,23 @@
* under the License.
*/
-// add enabled Spark version modules to the build
-def sparkVersions = (System.getProperty("sparkVersions") != null ? System.getProperty("sparkVersions") : System.getProperty("defaultSparkVersions")).split(",")
+package org.apache.spark.sql.execution.datasources.v2
-if (jdkVersion == '8' && sparkVersions.contains("2.4")) {
- apply from: file("$projectDir/v2.4/build.gradle")
-}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.connector.iceberg.catalog.Procedure
+
+case class CallExec(
+ output: Seq[Attribute],
+ procedure: Procedure,
+ input: InternalRow) extends LeafV2CommandExec {
+
+ override protected def run(): Seq[InternalRow] = {
+ procedure.call(input)
+ }
-if (sparkVersions.contains("3.0")) {
- apply from: file("$projectDir/v3.0/build.gradle")
+ override def simpleString(maxFields: Int): String = {
+ s"CallExec${truncatedString(output, "[", ", ", "]", maxFields)} ${procedure.description}"
+ }
}
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropIdentifierFieldsExec.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropIdentifierFieldsExec.scala
new file mode 100644
index 0000000..dee778b
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropIdentifierFieldsExec.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions
+import org.apache.iceberg.relocated.com.google.common.collect.Sets
+import org.apache.iceberg.spark.source.SparkTable
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.connector.catalog.TableCatalog
+
+case class DropIdentifierFieldsExec(
+ catalog: TableCatalog,
+ ident: Identifier,
+ fields: Seq[String]) extends LeafV2CommandExec {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+ override lazy val output: Seq[Attribute] = Nil
+
+ override protected def run(): Seq[InternalRow] = {
+ catalog.loadTable(ident) match {
+ case iceberg: SparkTable =>
+ val schema = iceberg.table.schema
+ val identifierFieldNames = Sets.newHashSet(schema.identifierFieldNames)
+
+ for (name <- fields) {
+ Preconditions.checkArgument(schema.findField(name) != null,
+ "Cannot complete drop identifier fields operation: field %s not found", name)
+ Preconditions.checkArgument(identifierFieldNames.contains(name),
+ "Cannot complete drop identifier fields operation: %s is not an identifier field", name)
+ identifierFieldNames.remove(name)
+ }
+
+ iceberg.table.updateSchema()
+ .setIdentifierFields(identifierFieldNames)
+ .commit()
+ case table =>
+ throw new UnsupportedOperationException(s"Cannot drop identifier fields in non-Iceberg table: $table")
+ }
+
+ Nil
+ }
+
+ override def simpleString(maxFields: Int): String = {
+ s"DropIdentifierFields ${catalog.name}.${ident.quoted} (${fields.quoted})";
+ }
+}
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropPartitionFieldExec.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropPartitionFieldExec.scala
new file mode 100644
index 0000000..9a153f0
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropPartitionFieldExec.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.iceberg.spark.Spark3Util
+import org.apache.iceberg.spark.source.SparkTable
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.connector.catalog.TableCatalog
+import org.apache.spark.sql.connector.expressions.FieldReference
+import org.apache.spark.sql.connector.expressions.IdentityTransform
+import org.apache.spark.sql.connector.expressions.Transform
+
+case class DropPartitionFieldExec(
+ catalog: TableCatalog,
+ ident: Identifier,
+ transform: Transform) extends LeafV2CommandExec {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+ override lazy val output: Seq[Attribute] = Nil
+
+ override protected def run(): Seq[InternalRow] = {
+ catalog.loadTable(ident) match {
+ case iceberg: SparkTable =>
+ val schema = iceberg.table.schema
+ transform match {
+ case IdentityTransform(FieldReference(parts)) if parts.size == 1 && schema.findField(parts.head) == null =>
+ // the name is not present in the Iceberg schema, so it must be a partition field name, not a column name
+ iceberg.table.updateSpec()
+ .removeField(parts.head)
+ .commit()
+
+ case _ =>
+ iceberg.table.updateSpec()
+ .removeField(Spark3Util.toIcebergTerm(transform))
+ .commit()
+ }
+
+ case table =>
+ throw new UnsupportedOperationException(s"Cannot drop partition field in non-Iceberg table: $table")
+ }
+
+ Nil
+ }
+
+ override def simpleString(maxFields: Int): String = {
+ s"DropPartitionField ${catalog.name}.${ident.quoted} ${transform.describe}"
+ }
+}
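The match above distinguishes two cases: an identity reference whose name is not a column in the Iceberg schema is treated as a partition field name (for example one created with an AS clause), while any other transform is converted with Spark3Util.toIcebergTerm. A hedged sketch with made-up names:

    // removed as a transform: days(ts) is converted to an Iceberg term
    spark.sql("ALTER TABLE db.tbl DROP PARTITION FIELD days(ts)")

    // "id_bucket" is not a column in the schema, so it is removed by partition field name
    spark.sql("ALTER TABLE db.tbl DROP PARTITION FIELD id_bucket")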
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala
new file mode 100644
index 0000000..63077ec
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.iceberg.spark.Spark3Util
+import org.apache.iceberg.spark.SparkCatalog
+import org.apache.iceberg.spark.SparkSessionCatalog
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.Strategy
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.catalyst.plans.logical.AddPartitionField
+import org.apache.spark.sql.catalyst.plans.logical.Call
+import org.apache.spark.sql.catalyst.plans.logical.DropIdentifierFields
+import org.apache.spark.sql.catalyst.plans.logical.DropPartitionField
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.ReplacePartitionField
+import org.apache.spark.sql.catalyst.plans.logical.SetIdentifierFields
+import org.apache.spark.sql.catalyst.plans.logical.SetWriteDistributionAndOrdering
+import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.connector.catalog.TableCatalog
+import org.apache.spark.sql.execution.SparkPlan
+import scala.collection.JavaConverters._
+
+case class ExtendedDataSourceV2Strategy(spark: SparkSession) extends Strategy {
+
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case c @ Call(procedure, args) =>
+ val input = buildInternalRow(args)
+ CallExec(c.output, procedure, input) :: Nil
+
+ case AddPartitionField(IcebergCatalogAndIdentifier(catalog, ident), transform, name) =>
+ AddPartitionFieldExec(catalog, ident, transform, name) :: Nil
+
+ case DropPartitionField(IcebergCatalogAndIdentifier(catalog, ident), transform) =>
+ DropPartitionFieldExec(catalog, ident, transform) :: Nil
+
+ case ReplacePartitionField(IcebergCatalogAndIdentifier(catalog, ident), transformFrom, transformTo, name) =>
+ ReplacePartitionFieldExec(catalog, ident, transformFrom, transformTo, name) :: Nil
+
+ case SetIdentifierFields(IcebergCatalogAndIdentifier(catalog, ident), fields) =>
+ SetIdentifierFieldsExec(catalog, ident, fields) :: Nil
+
+ case DropIdentifierFields(IcebergCatalogAndIdentifier(catalog, ident), fields) =>
+ DropIdentifierFieldsExec(catalog, ident, fields) :: Nil
+
+ case SetWriteDistributionAndOrdering(
+ IcebergCatalogAndIdentifier(catalog, ident), distributionMode, ordering) =>
+ SetWriteDistributionAndOrderingExec(catalog, ident, distributionMode, ordering) :: Nil
+
+ case _ => Nil
+ }
+
+ private def buildInternalRow(exprs: Seq[Expression]): InternalRow = {
+ val values = new Array[Any](exprs.size)
+ for (index <- exprs.indices) {
+ values(index) = exprs(index).eval()
+ }
+ new GenericInternalRow(values)
+ }
+
+ private object IcebergCatalogAndIdentifier {
+ def unapply(identifier: Seq[String]): Option[(TableCatalog, Identifier)] = {
+ val catalogAndIdentifier = Spark3Util.catalogAndIdentifier(spark, identifier.asJava)
+ catalogAndIdentifier.catalog match {
+ case icebergCatalog: SparkCatalog =>
+ Some((icebergCatalog, catalogAndIdentifier.identifier))
+ case icebergCatalog: SparkSessionCatalog[_] =>
+ Some((icebergCatalog, catalogAndIdentifier.identifier))
+ case _ =>
+ None
+ }
+ }
+ }
+}
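This strategy is registered by IcebergSparkSessionExtensions together with the parser and analysis rules, and the IcebergCatalogAndIdentifier extractor only matches tables that resolve to SparkCatalog or SparkSessionCatalog, so other catalogs fall back to Spark's regular planning. A minimal sketch of enabling the extensions (the catalog name and warehouse path are placeholders):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[2]")
      .config("spark.sql.extensions",
        "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
      .config("spark.sql.catalog.my_catalog", "org.apache.iceberg.spark.SparkCatalog")
      .config("spark.sql.catalog.my_catalog.type", "hadoop")
      .config("spark.sql.catalog.my_catalog.warehouse", "/tmp/warehouse")
      .getOrCreate()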
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplacePartitionFieldExec.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplacePartitionFieldExec.scala
new file mode 100644
index 0000000..fcae0a5
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplacePartitionFieldExec.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.iceberg.spark.Spark3Util
+import org.apache.iceberg.spark.source.SparkTable
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.connector.catalog.TableCatalog
+import org.apache.spark.sql.connector.expressions.FieldReference
+import org.apache.spark.sql.connector.expressions.IdentityTransform
+import org.apache.spark.sql.connector.expressions.Transform
+
+case class ReplacePartitionFieldExec(
+ catalog: TableCatalog,
+ ident: Identifier,
+ transformFrom: Transform,
+ transformTo: Transform,
+ name: Option[String]) extends LeafV2CommandExec {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+ override lazy val output: Seq[Attribute] = Nil
+
+ override protected def run(): Seq[InternalRow] = {
+ catalog.loadTable(ident) match {
+ case iceberg: SparkTable =>
+ val schema = iceberg.table.schema
+ transformFrom match {
+ case IdentityTransform(FieldReference(parts)) if parts.size == 1 && schema.findField(parts.head) == null =>
+ // the name is not present in the Iceberg schema, so it must be a partition field name, not a column name
+ iceberg.table.updateSpec()
+ .removeField(parts.head)
+ .addField(name.orNull, Spark3Util.toIcebergTerm(transformTo))
+ .commit()
+
+ case _ =>
+ iceberg.table.updateSpec()
+ .removeField(Spark3Util.toIcebergTerm(transformFrom))
+ .addField(name.orNull, Spark3Util.toIcebergTerm(transformTo))
+ .commit()
+ }
+
+ case table =>
+ throw new UnsupportedOperationException(s"Cannot replace partition field in non-Iceberg table: $table")
+ }
+
+ Nil
+ }
+
+ override def simpleString(maxFields: Int): String = {
+ s"ReplacePartitionField ${catalog.name}.${ident.quoted} ${transformFrom.describe} " +
+ s"with ${name.map(n => s"$n=").getOrElse("")}${transformTo.describe}"
+ }
+}
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetIdentifierFieldsExec.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetIdentifierFieldsExec.scala
new file mode 100644
index 0000000..4c23653
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetIdentifierFieldsExec.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.iceberg.spark.source.SparkTable
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.connector.catalog.TableCatalog
+
+case class SetIdentifierFieldsExec(
+ catalog: TableCatalog,
+ ident: Identifier,
+ fields: Seq[String]) extends LeafV2CommandExec {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+ override lazy val output: Seq[Attribute] = Nil
+
+ override protected def run(): Seq[InternalRow] = {
+ catalog.loadTable(ident) match {
+ case iceberg: SparkTable =>
+ iceberg.table.updateSchema()
+ .setIdentifierFields(scala.collection.JavaConverters.seqAsJavaList(fields))
+ .commit()
+ case table =>
+ throw new UnsupportedOperationException(s"Cannot set identifier fields in non-Iceberg table: $table")
+ }
+
+ Nil
+ }
+
+ override def simpleString(maxFields: Int): String = {
+ s"SetIdentifierFields ${catalog.name}.${ident.quoted} (${fields.quoted})";
+ }
+}
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetWriteDistributionAndOrderingExec.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetWriteDistributionAndOrderingExec.scala
new file mode 100644
index 0000000..386485b
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetWriteDistributionAndOrderingExec.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.iceberg.DistributionMode
+import org.apache.iceberg.NullOrder
+import org.apache.iceberg.SortDirection
+import org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE
+import org.apache.iceberg.expressions.Term
+import org.apache.iceberg.spark.source.SparkTable
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits
+import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.connector.catalog.TableCatalog
+
+case class SetWriteDistributionAndOrderingExec(
+ catalog: TableCatalog,
+ ident: Identifier,
+ distributionMode: DistributionMode,
+ sortOrder: Seq[(Term, SortDirection, NullOrder)]) extends LeafV2CommandExec {
+
+ import CatalogV2Implicits._
+
+ override lazy val output: Seq[Attribute] = Nil
+
+ override protected def run(): Seq[InternalRow] = {
+ catalog.loadTable(ident) match {
+ case iceberg: SparkTable =>
+ val txn = iceberg.table.newTransaction()
+
+ val orderBuilder = txn.replaceSortOrder()
+ sortOrder.foreach {
+ case (term, SortDirection.ASC, nullOrder) =>
+ orderBuilder.asc(term, nullOrder)
+ case (term, SortDirection.DESC, nullOrder) =>
+ orderBuilder.desc(term, nullOrder)
+ }
+ orderBuilder.commit()
+
+ txn.updateProperties()
+ .set(WRITE_DISTRIBUTION_MODE, distributionMode.modeName())
+ .commit()
+
+ txn.commitTransaction()
+
+ case table =>
+ throw new UnsupportedOperationException(s"Cannot set write order of non-Iceberg table: $table")
+ }
+
+ Nil
+ }
+
+ override def simpleString(maxFields: Int): String = {
+ val tableIdent = s"${catalog.name}.${ident.quoted}"
+ val order = sortOrder.map {
+ case (term, direction, nullOrder) => s"$term $direction $nullOrder"
+ }.mkString(", ")
+ s"SetWriteDistributionAndOrdering $tableIdent $distributionMode $order"
+ }
+}
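SetWriteDistributionAndOrderingExec applies both changes in a single Iceberg transaction: the sort order is replaced and the distribution mode is stored as the write.distribution-mode table property. Roughly the same effect, expressed against the Iceberg Table API (a sketch, assuming `table` is the underlying org.apache.iceberg.Table):

    import org.apache.iceberg.NullOrder
    import org.apache.iceberg.TableProperties

    // approximately what ALTER TABLE ... WRITE ORDERED BY id commits
    val txn = table.newTransaction()
    txn.replaceSortOrder()
      .asc("id", NullOrder.NULLS_FIRST)
      .commit()
    txn.updateProperties()
      .set(TableProperties.WRITE_DISTRIBUTION_MODE, "range")
      .commit()
    txn.commitTransaction()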
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/Employee.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/Employee.java
new file mode 100644
index 0000000..51ac578
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/Employee.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.util.Objects;
+
+public class Employee {
+ private Integer id;
+ private String dep;
+
+ public Employee() {
+ }
+
+ public Employee(Integer id, String dep) {
+ this.id = id;
+ this.dep = dep;
+ }
+
+ public Integer getId() {
+ return id;
+ }
+
+ public void setId(Integer id) {
+ this.id = id;
+ }
+
+ public String getDep() {
+ return dep;
+ }
+
+ public void setDep(String dep) {
+ this.dep = dep;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (this == other) {
+ return true;
+ } else if (other == null || getClass() != other.getClass()) {
+ return false;
+ }
+
+ Employee employee = (Employee) other;
+ return Objects.equals(id, employee.id) && Objects.equals(dep, employee.dep);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(id, dep);
+ }
+}
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkExtensionsTestBase.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkExtensionsTestBase.java
new file mode 100644
index 0000000..36ca608
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkExtensionsTestBase.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.util.Map;
+import org.apache.iceberg.CatalogUtil;
+import org.apache.iceberg.hive.HiveCatalog;
+import org.apache.iceberg.hive.TestHiveMetastore;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.spark.SparkCatalogTestBase;
+import org.apache.iceberg.spark.SparkTestBase;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.internal.SQLConf;
+import org.junit.BeforeClass;
+
+import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS;
+
+public abstract class SparkExtensionsTestBase extends SparkCatalogTestBase {
+
+ public SparkExtensionsTestBase(String catalogName, String implementation, Map<String, String> config) {
+ super(catalogName, implementation, config);
+ }
+
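+ // Starts a local Hive metastore and a SparkSession with the Iceberg SQL extensions enabled,
+ // shared by all extension tests.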
+ @BeforeClass
+ public static void startMetastoreAndSpark() {
+ SparkTestBase.metastore = new TestHiveMetastore();
+ metastore.start();
+ SparkTestBase.hiveConf = metastore.hiveConf();
+
+ SparkTestBase.spark = SparkSession.builder()
+ .master("local[2]")
+ .config("spark.testing", "true")
+ .config(SQLConf.PARTITION_OVERWRITE_MODE().key(), "dynamic")
+ .config("spark.sql.extensions", IcebergSparkSessionExtensions.class.getName())
+ .config("spark.hadoop." + METASTOREURIS.varname, hiveConf.get(METASTOREURIS.varname))
+ .config("spark.sql.shuffle.partitions", "4")
+ .config("spark.sql.hive.metastorePartitionPruningFallbackOnException", "true")
+ .enableHiveSupport()
+ .getOrCreate();
+
+ SparkTestBase.catalog = (HiveCatalog)
+ CatalogUtil.loadCatalog(HiveCatalog.class.getName(), "hive", ImmutableMap.of(), hiveConf);
+ }
+}
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java
new file mode 100644
index 0000000..1dc0bd4
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.stream.Collectors;
+import org.apache.iceberg.Snapshot;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.spark.SparkCatalog;
+import org.apache.iceberg.spark.SparkSessionCatalog;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.junit.Assert;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT;
+import static org.apache.iceberg.TableProperties.PARQUET_VECTORIZATION_ENABLED;
+import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE;
+import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_HASH;
+import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_NONE;
+import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_RANGE;
+
+@RunWith(Parameterized.class)
+public abstract class SparkRowLevelOperationsTestBase extends SparkExtensionsTestBase {
+
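+ // Used to randomize the vectorized flag for the Parquet test configuration below.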
+ private static final Random RANDOM = ThreadLocalRandom.current();
+
+ protected final String fileFormat;
+ protected final boolean vectorized;
+ protected final String distributionMode;
+
+ public SparkRowLevelOperationsTestBase(String catalogName, String implementation,
+ Map<String, String> config, String fileFormat,
+ boolean vectorized,
+ String distributionMode) {
+ super(catalogName, implementation, config);
+ this.fileFormat = fileFormat;
+ this.vectorized = vectorized;
+ this.distributionMode = distributionMode;
+ }
+
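+ // Each parameter row exercises a different catalog implementation, file format,
+ // vectorization setting and write distribution mode.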
+ @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}," +
+ " format = {3}, vectorized = {4}, distributionMode = {5}")
+ public static Object[][] parameters() {
+ return new Object[][] {
+ { "testhive", SparkCatalog.class.getName(),
+ ImmutableMap.of(
+ "type", "hive",
+ "default-namespace", "default"
+ ),
+ "orc",
+ true,
+ WRITE_DISTRIBUTION_MODE_NONE
+ },
+ { "testhadoop", SparkCatalog.class.getName(),
+ ImmutableMap.of(
+ "type", "hadoop"
+ ),
+ "parquet",
+ RANDOM.nextBoolean(),
+ WRITE_DISTRIBUTION_MODE_HASH
+ },
+ { "spark_catalog", SparkSessionCatalog.class.getName(),
+ ImmutableMap.of(
+ "type", "hive",
+ "default-namespace", "default",
+ "clients", "1",
+ "parquet-enabled", "false",
+ "cache-enabled", "false" // Spark will delete tables using v1, leaving the cache out of sync
+ ),
+ "avro",
+ false,
+ WRITE_DISTRIBUTION_MODE_RANGE
+ }
+ };
+ }
+
+ protected abstract Map<String, String> extraTableProperties();
+
+ protected void initTable() {
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, DEFAULT_FILE_FORMAT, fileFormat);
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, WRITE_DISTRIBUTION_MODE, distributionMode);
+
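+ // Vectorized reads are only configurable for Parquet here; for ORC and Avro the
+ // parameterized flag is only sanity-checked against the expected defaults.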
+ switch (fileFormat) {
+ case "parquet":
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%b')", tableName, PARQUET_VECTORIZATION_ENABLED, vectorized);
+ break;
+ case "orc":
+ Assert.assertTrue(vectorized);
+ break;
+ case "avro":
+ Assert.assertFalse(vectorized);
+ break;
+ }
+
+ Map<String, String> props = extraTableProperties();
+ props.forEach((prop, value) -> {
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, prop, value);
+ });
+ }
+
+ protected void createAndInitTable(String schema) {
+ createAndInitTable(schema, null);
+ }
+
+ protected void createAndInitTable(String schema, String jsonData) {
+ sql("CREATE TABLE %s (%s) USING iceberg", tableName, schema);
+ initTable();
+
+ if (jsonData != null) {
+ try {
+ Dataset<Row> ds = toDS(schema, jsonData);
+ ds.writeTo(tableName).append();
+ } catch (NoSuchTableException e) {
+ throw new RuntimeException("Failed to write data", e);
+ }
+ }
+ }
+
+ protected void append(String table, String jsonData) {
+ append(table, null, jsonData);
+ }
+
+ protected void append(String table, String schema, String jsonData) {
+ try {
+ Dataset<Row> ds = toDS(schema, jsonData);
+ ds.coalesce(1).writeTo(table).append();
+ } catch (NoSuchTableException e) {
+ throw new RuntimeException("Failed to write data", e);
+ }
+ }
+
+ protected void createOrReplaceView(String name, String jsonData) {
+ createOrReplaceView(name, null, jsonData);
+ }
+
+ protected void createOrReplaceView(String name, String schema, String jsonData) {
+ Dataset<Row> ds = toDS(schema, jsonData);
+ ds.createOrReplaceTempView(name);
+ }
+
+ protected <T> void createOrReplaceView(String name, List<T> data, Encoder<T> encoder) {
+ spark.createDataset(data, encoder).createOrReplaceTempView(name);
+ }
+
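+ // Parses newline-delimited JSON rows into a DataFrame, applying the given schema when provided.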
+ private Dataset<Row> toDS(String schema, String jsonData) {
+ List<String> jsonRows = Arrays.stream(jsonData.split("\n"))
+ .filter(str -> str.trim().length() > 0)
+ .collect(Collectors.toList());
+ Dataset<String> jsonDS = spark.createDataset(jsonRows, Encoders.STRING());
+
+ if (schema != null) {
+ return spark.read().schema(schema).json(jsonDS);
+ } else {
+ return spark.read().json(jsonDS);
+ }
+ }
+
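+ // Verifies the snapshot operation and the relevant counters recorded in the snapshot summary.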
+ protected void validateSnapshot(Snapshot snapshot, String operation, String changedPartitionCount,
+ String deletedDataFiles, String addedDataFiles) {
+ Assert.assertEquals("Operation must match", operation, snapshot.operation());
+ Assert.assertEquals("Changed partitions count must match",
+ changedPartitionCount,
+ snapshot.summary().get("changed-partition-count"));
+ Assert.assertEquals("Deleted data files count must match",
+ deletedDataFiles,
+ snapshot.summary().get("deleted-data-files"));
+ Assert.assertEquals("Added data files count must match",
+ addedDataFiles,
+ snapshot.summary().get("added-data-files"));
+ }
+
+ protected void sleep(long millis) {
+ try {
+ Thread.sleep(millis);
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java
new file mode 100644
index 0000000..7f5c7df
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java
@@ -0,0 +1,738 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import org.apache.avro.Schema;
+import org.apache.avro.SchemaBuilder;
+import org.apache.avro.file.DataFileWriter;
+import org.apache.avro.generic.GenericData;
+import org.apache.avro.generic.GenericDatumWriter;
+import org.apache.avro.generic.GenericRecord;
+import org.apache.avro.io.DatumWriter;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.DataFile;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Assume;
+import org.junit.Before;
+import org.junit.Ignore;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+public class TestAddFilesProcedure extends SparkExtensionsTestBase {
+
+ private final String sourceTableName = "source_table";
+ private File fileTableDir;
+
+ public TestAddFilesProcedure(String catalogName, String implementation, Map<String, String> config) {
+ super(catalogName, implementation, config);
+ }
+
+ @Rule
+ public TemporaryFolder temp = new TemporaryFolder();
+
+ @Before
+ public void setupTempDirs() {
+ try {
+ fileTableDir = temp.newFolder();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @After
+ public void dropTables() {
+ sql("DROP TABLE IF EXISTS %s", sourceTableName);
+ sql("DROP TABLE IF EXISTS %s", tableName);
+ }
+
+ @Test
+ public void addDataUnpartitioned() {
+ createUnpartitionedFileTable("parquet");
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg";
+
+ sql(createIceberg, tableName);
+
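+ // The second argument is a path-based source table reference of the form `format`.`path`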
+ Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`')",
+ catalogName, tableName, fileTableDir.getAbsolutePath());
+
+ Assert.assertEquals(2L, result);
+
+ assertEquals("Iceberg table contains correct data",
+ sql("SELECT * FROM %s ORDER BY id", sourceTableName),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Ignore // TODO Classpath issues prevent us from actually writing to a Spark ORC table
+ public void addDataUnpartitionedOrc() {
+ createUnpartitionedFileTable("orc");
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg";
+
+ sql(createIceberg, tableName);
+
+ Object result = scalarSql("CALL %s.system.add_files('%s', '`orc`.`%s`')",
+ catalogName, tableName, fileTableDir.getAbsolutePath());
+
+ Assert.assertEquals(2L, result);
+
+ assertEquals("Iceberg table contains correct data",
+ sql("SELECT * FROM %s ORDER BY id", sourceTableName),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void addAvroFile() throws Exception {
+ // The Spark session catalog cannot load metadata tables; it fails
+ // with "The namespace in session catalog must have exactly one name part"
+ Assume.assumeFalse(catalogName.equals("spark_catalog"));
+
+ // Create an Avro file
+
+ Schema schema = SchemaBuilder.record("record").fields()
+ .requiredInt("id")
+ .requiredString("data")
+ .endRecord();
+ GenericRecord record1 = new GenericData.Record(schema);
+ record1.put("id", 1L);
+ record1.put("data", "a");
+ GenericRecord record2 = new GenericData.Record(schema);
+ record2.put("id", 2L);
+ record2.put("data", "b");
+ File outputFile = temp.newFile("test.avro");
+
+ DatumWriter<GenericRecord> datumWriter = new GenericDatumWriter<>(schema);
+ DataFileWriter<GenericRecord> dataFileWriter = new DataFileWriter<>(datumWriter);
+ dataFileWriter.create(schema, outputFile);
+ dataFileWriter.append(record1);
+ dataFileWriter.append(record2);
+ dataFileWriter.close();
+
+ String createIceberg =
+ "CREATE TABLE %s (id Long, data String) USING iceberg";
+ sql(createIceberg, tableName);
+
+ Object result = scalarSql("CALL %s.system.add_files('%s', '`avro`.`%s`')",
+ catalogName, tableName, outputFile.getPath());
+ Assert.assertEquals(1L, result);
+
+ List<Object[]> expected = Lists.newArrayList(
+ new Object[]{1L, "a"},
+ new Object[]{2L, "b"}
+ );
+
+ assertEquals("Iceberg table contains correct data",
+ expected,
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ List<Object[]> actualRecordCount = sql("select %s from %s.files",
+ DataFile.RECORD_COUNT.name(),
+ tableName);
+ List<Object[]> expectedRecordCount = Lists.newArrayList();
+ expectedRecordCount.add(new Object[]{2L});
+ assertEquals("Iceberg file metadata should have correct metadata count",
+ expectedRecordCount, actualRecordCount);
+ }
+
+ // TODO Adding spark-avro doesn't work in tests
+ @Ignore
+ public void addDataUnpartitionedAvro() {
+ createUnpartitionedFileTable("avro");
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg";
+
+ sql(createIceberg, tableName);
+
+ Object result = scalarSql("CALL %s.system.add_files('%s', '`avro`.`%s`')",
+ catalogName, tableName, fileTableDir.getAbsolutePath());
+
+ Assert.assertEquals(2L, result);
+
+ assertEquals("Iceberg table contains correct data",
+ sql("SELECT * FROM %s ORDER BY id", sourceTableName),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void addDataUnpartitionedHive() {
+ createUnpartitionedHiveTable();
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg";
+
+ sql(createIceberg, tableName);
+
+ Object result = scalarSql("CALL %s.system.add_files('%s', '%s')",
+ catalogName, tableName, sourceTableName);
+
+ Assert.assertEquals(2L, result);
+
+ assertEquals("Iceberg table contains correct data",
+ sql("SELECT * FROM %s ORDER BY id", sourceTableName),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void addDataUnpartitionedExtraCol() {
+ createUnpartitionedFileTable("parquet");
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String, foo string) USING iceberg";
+
+ sql(createIceberg, tableName);
+
+ Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`')",
+ catalogName, tableName, fileTableDir.getAbsolutePath());
+
+ Assert.assertEquals(2L, result);
+
+ assertEquals("Iceberg table contains correct data",
+ sql("SELECT * FROM %s ORDER BY id", sourceTableName),
+ sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void addDataUnpartitionedMissingCol() {
+ createUnpartitionedFileTable("parquet");
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String) USING iceberg";
+
+ sql(createIceberg, tableName);
+
+ Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`')",
+ catalogName, tableName, fileTableDir.getAbsolutePath());
+
+ Assert.assertEquals(2L, result);
+
+ assertEquals("Iceberg table contains correct data",
+ sql("SELECT id, name, dept FROM %s ORDER BY id", sourceTableName),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void addDataPartitionedMissingCol() {
+ createPartitionedFileTable("parquet");
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String) USING iceberg PARTITIONED BY (id)";
+
+ sql(createIceberg, tableName);
+
+ Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`')",
+ catalogName, tableName, fileTableDir.getAbsolutePath());
+
+ Assert.assertEquals(8L, result);
+
+ assertEquals("Iceberg table contains correct data",
+ sql("SELECT id, name, dept FROM %s ORDER BY id", sourceTableName),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void addDataPartitioned() {
+ createPartitionedFileTable("parquet");
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)";
+
+ sql(createIceberg, tableName);
+
+ Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`')",
+ catalogName, tableName, fileTableDir.getAbsolutePath());
+
+ Assert.assertEquals(8L, result);
+
+ assertEquals("Iceberg table contains correct data",
+ sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName),
+ sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName));
+ }
+
+ @Ignore // TODO Classpath issues prevent us from actually writing to a Spark ORC table
+ public void addDataPartitionedOrc() {
+ createPartitionedFileTable("orc");
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)";
+
+ sql(createIceberg, tableName);
+
+ Object result = scalarSql("CALL %s.system.add_files('%s', '`orc`.`%s`')",
+ catalogName, tableName, fileTableDir.getAbsolutePath());
+
+ Assert.assertEquals(8L, result);
+
+ assertEquals("Iceberg table contains correct data",
+ sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName),
+ sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName));
+ }
+
+ // TODO Adding spark-avro doesn't work in tests
+ @Ignore
+ public void addDataPartitionedAvro() {
+ createPartitionedFileTable("avro");
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)";
+
+ sql(createIceberg, tableName);
+
+ Object result = scalarSql("CALL %s.system.add_files('%s', '`avro`.`%s`')",
+ catalogName, tableName, fileTableDir.getAbsolutePath());
+
+ Assert.assertEquals(8L, result);
+
+ assertEquals("Iceberg table contains correct data",
+ sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName),
+ sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void addDataPartitionedHive() {
+ createPartitionedHiveTable();
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)";
+
+ sql(createIceberg, tableName);
+
+ Object result = scalarSql("CALL %s.system.add_files('%s', '%s')",
+ catalogName, tableName, sourceTableName);
+
+ Assert.assertEquals(8L, result);
+
+ assertEquals("Iceberg table contains correct data",
+ sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName),
+ sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void addPartitionToPartitioned() {
+ createPartitionedFileTable("parquet");
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)";
+
+ sql(createIceberg, tableName);
+
+ Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))",
+ catalogName, tableName, fileTableDir.getAbsolutePath());
+
+ Assert.assertEquals(2L, result);
+
+ assertEquals("Iceberg table contains correct data",
+ sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", sourceTableName),
+ sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void addFilteredPartitionsToPartitioned() {
+ createCompositePartitionedTable("parquet");
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg " +
+ "PARTITIONED BY (id, dept)";
+
+ sql(createIceberg, tableName);
+
+ Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))",
+ catalogName, tableName, fileTableDir.getAbsolutePath());
+
+ Assert.assertEquals(2L, result);
+
+ assertEquals("Iceberg table contains correct data",
+ sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", sourceTableName),
+ sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void addFilteredPartitionsToPartitioned2() {
+ createCompositePartitionedTable("parquet");
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg " +
+ "PARTITIONED BY (id, dept)";
+
+ sql(createIceberg, tableName);
+
+ Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`', map('dept', 'hr'))",
+ catalogName, tableName, fileTableDir.getAbsolutePath());
+
+ Assert.assertEquals(6L, result);
+
+ assertEquals("Iceberg table contains correct data",
+ sql("SELECT id, name, dept, subdept FROM %s WHERE dept = 'hr' ORDER BY id", sourceTableName),
+ sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void addWeirdCaseHiveTable() {
+ createWeirdCaseTable();
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, `naMe` String, dept String, subdept String) USING iceberg " +
+ "PARTITIONED BY (`naMe`)";
+
+ sql(createIceberg, tableName);
+
+ Object result = scalarSql("CALL %s.system.add_files('%s', '%s', map('naMe', 'John Doe'))",
+ catalogName, tableName, sourceTableName);
+
+ Assert.assertEquals(2L, result);
+
+ /*
+ Ideally we would use
+ SELECT id, `naMe`, dept, subdept FROM %s WHERE `naMe` = 'John Doe' ORDER BY id
+ but Spark does not handle this pushdown correctly for Hive-based tables and returns 0 records,
+ so the expected rows are filtered in memory instead.
+ */
+ List<Object[]> expected =
+ sql("SELECT id, `naMe`, dept, subdept from %s ORDER BY id", sourceTableName)
+ .stream()
+ .filter(r -> r[1].equals("John Doe"))
+ .collect(Collectors.toList());
+
+ // TODO: when this assertion starts failing, Spark has fixed the pushdown issue
+ Assert.assertEquals("If this assert breaks it means that Spark has fixed the pushdown issue", 0,
+ sql("SELECT id, `naMe`, dept, subdept from %s WHERE `naMe` = 'John Doe' ORDER BY id", sourceTableName)
+ .size());
+
+ // Pushdown works correctly for Iceberg tables
+ Assert.assertEquals("We should be able to pushdown mixed case partition keys", 2,
+ sql("SELECT id, `naMe`, dept, subdept FROM %s WHERE `naMe` = 'John Doe' ORDER BY id", tableName)
+ .size());
+
+ assertEquals("Iceberg table contains correct data",
+ expected,
+ sql("SELECT id, `naMe`, dept, subdept FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void addPartitionToPartitionedHive() {
+ createPartitionedHiveTable();
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)";
+
+ sql(createIceberg, tableName);
+
+ Object result = scalarSql("CALL %s.system.add_files('%s', '%s', map('id', 1))",
+ catalogName, tableName, sourceTableName);
+
+ Assert.assertEquals(2L, result);
+
+ assertEquals("Iceberg table contains correct data",
+ sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", sourceTableName),
+ sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void invalidDataImport() {
+ createPartitionedFileTable("parquet");
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg";
+
+ sql(createIceberg, tableName);
+
+ AssertHelpers.assertThrows("Should forbid adding of partitioned data to unpartitioned table",
+ IllegalArgumentException.class,
+ "Cannot use partition filter with an unpartitioned table",
+ () -> scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`', map('id', 1))",
+ catalogName, tableName, fileTableDir.getAbsolutePath())
+ );
+
+ AssertHelpers.assertThrows("Should forbid adding of partitioned data to unpartitioned table",
+ IllegalArgumentException.class,
+ "Cannot add partitioned files to an unpartitioned table",
+ () -> scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`')",
+ catalogName, tableName, fileTableDir.getAbsolutePath())
+ );
+ }
+
+ @Test
+ public void invalidDataImportPartitioned() {
+ createUnpartitionedFileTable("parquet");
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)";
+
+ sql(createIceberg, tableName);
+
+ AssertHelpers.assertThrows("Should forbid adding with a mismatching partition spec",
+ IllegalArgumentException.class,
+ "is greater than the number of partitioned columns",
+ () -> scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`', map('x', '1', 'y', '2'))",
+ catalogName, tableName, fileTableDir.getAbsolutePath()));
+
+ AssertHelpers.assertThrows("Should forbid adding with partition spec with incorrect columns",
+ IllegalArgumentException.class,
+ "specified partition filter refers to columns that are not partitioned",
+ () -> scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`', map('dept', '2'))",
+ catalogName, tableName, fileTableDir.getAbsolutePath()));
+ }
+
+ @Test
+ public void addTwice() {
+ createPartitionedHiveTable();
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)";
+
+ sql(createIceberg, tableName);
+
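+ // Invokes the procedure with named arguments instead of positional ones.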
+ Object result1 = scalarSql("CALL %s.system.add_files(" +
+ "table => '%s', " +
+ "source_table => '%s', " +
+ "partition_filter => map('id', 1))",
+ catalogName, tableName, sourceTableName);
+ Assert.assertEquals(2L, result1);
+
+ Object result2 = scalarSql("CALL %s.system.add_files(" +
+ "table => '%s', " +
+ "source_table => '%s', " +
+ "partition_filter => map('id', 2))",
+ catalogName, tableName, sourceTableName);
+ Assert.assertEquals(2L, result2);
+
+ assertEquals("Iceberg table contains correct data",
+ sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", sourceTableName),
+ sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 ORDER BY id", tableName));
+ assertEquals("Iceberg table contains correct data",
+ sql("SELECT id, name, dept, subdept FROM %s WHERE id = 2 ORDER BY id", sourceTableName),
+ sql("SELECT id, name, dept, subdept FROM %s WHERE id = 2 ORDER BY id", tableName));
+ }
+
+ @Test
+ public void duplicateDataPartitioned() {
+ createPartitionedHiveTable();
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)";
+
+ sql(createIceberg, tableName);
+
+ scalarSql("CALL %s.system.add_files(" +
+ "table => '%s', " +
+ "source_table => '%s', " +
+ "partition_filter => map('id', 1))",
+ catalogName, tableName, sourceTableName);
+
+ AssertHelpers.assertThrows("Should not allow adding duplicate files",
+ IllegalStateException.class,
+ "Cannot complete import because data files to be imported already" +
+ " exist within the target table",
+ () -> scalarSql("CALL %s.system.add_files(" +
+ "table => '%s', " +
+ "source_table => '%s', " +
+ "partition_filter => map('id', 1))",
+ catalogName, tableName, sourceTableName));
+ }
+
+ @Test
+ public void duplicateDataPartitionedAllowed() {
+ createPartitionedHiveTable();
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg PARTITIONED BY (id)";
+
+ sql(createIceberg, tableName);
+
+ Object result1 = scalarSql("CALL %s.system.add_files(" +
+ "table => '%s', " +
+ "source_table => '%s', " +
+ "partition_filter => map('id', 1))",
+ catalogName, tableName, sourceTableName);
+
+ Assert.assertEquals(2L, result1);
+
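+ // Re-adding the same files is allowed when check_duplicate_files is disabled.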
+ Object result2 = scalarSql("CALL %s.system.add_files(" +
+ "table => '%s', " +
+ "source_table => '%s', " +
+ "partition_filter => map('id', 1)," +
+ "check_duplicate_files => false)",
+ catalogName, tableName, sourceTableName);
+
+ Assert.assertEquals(2L, result2);
+
+ assertEquals("Iceberg table contains correct data",
+ sql("SELECT id, name, dept, subdept FROM %s WHERE id = 1 UNION ALL " +
+ "SELECT id, name, dept, subdept FROM %s WHERE id = 1", sourceTableName, sourceTableName),
+ sql("SELECT id, name, dept, subdept FROM %s", tableName, tableName));
+ }
+
+ @Test
+ public void duplicateDataUnpartitioned() {
+ createUnpartitionedHiveTable();
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg";
+
+ sql(createIceberg, tableName);
+
+ scalarSql("CALL %s.system.add_files('%s', '%s')",
+ catalogName, tableName, sourceTableName);
+
+ AssertHelpers.assertThrows("Should not allow adding duplicate files",
+ IllegalStateException.class,
+ "Cannot complete import because data files to be imported already" +
+ " exist within the target table",
+ () -> scalarSql("CALL %s.system.add_files('%s', '%s')",
+ catalogName, tableName, sourceTableName));
+ }
+
+ @Test
+ public void duplicateDataUnpartitionedAllowed() {
+ createUnpartitionedHiveTable();
+
+ String createIceberg =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg";
+
+ sql(createIceberg, tableName);
+
+ Object result1 = scalarSql("CALL %s.system.add_files('%s', '%s')",
+ catalogName, tableName, sourceTableName);
+ Assert.assertEquals(2L, result1);
+
+ Object result2 = scalarSql("CALL %s.system.add_files(" +
+ "table => '%s', " +
+ "source_table => '%s'," +
+ "check_duplicate_files => false)",
+ catalogName, tableName, sourceTableName);
+ Assert.assertEquals(2L, result2);
+
+ assertEquals("Iceberg table contains correct data",
+ sql("SELECT * FROM (SELECT * FROM %s UNION ALL " +
+ "SELECT * from %s) ORDER BY id", sourceTableName, sourceTableName),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ private static final StructField[] struct = {
+ new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
+ new StructField("name", DataTypes.StringType, false, Metadata.empty()),
+ new StructField("dept", DataTypes.StringType, false, Metadata.empty()),
+ new StructField("subdept", DataTypes.StringType, false, Metadata.empty())
+ };
+
+ private static final Dataset<Row> unpartitionedDF =
+ spark.createDataFrame(
+ ImmutableList.of(
+ RowFactory.create(1, "John Doe", "hr", "communications"),
+ RowFactory.create(2, "Jane Doe", "hr", "salary"),
+ RowFactory.create(3, "Matt Doe", "hr", "communications"),
+ RowFactory.create(4, "Will Doe", "facilities", "all")),
+ new StructType(struct)).repartition(1);
+
+ private static final Dataset<Row> partitionedDF =
+ unpartitionedDF.select("name", "dept", "subdept", "id");
+
+ private static final Dataset<Row> compositePartitionedDF =
+ unpartitionedDF.select("name", "subdept", "id", "dept");
+
+ private static final Dataset<Row> weirdColumnNamesDF =
+ unpartitionedDF.select(
+ unpartitionedDF.col("id"),
+ unpartitionedDF.col("subdept"),
+ unpartitionedDF.col("dept"),
+ unpartitionedDF.col("name").as("naMe"));
+
+
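+ // The helpers below create non-Iceberg source tables (file-based or Hive) for add_files to import from.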
+ private void createUnpartitionedFileTable(String format) {
+ String createParquet =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING %s LOCATION '%s'";
+
+ sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath());
+ unpartitionedDF.write().insertInto(sourceTableName);
+ unpartitionedDF.write().insertInto(sourceTableName);
+ }
+
+ private void createPartitionedFileTable(String format) {
+ String createParquet =
+ "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING %s PARTITIONED BY (id) " +
+ "LOCATION '%s'";
+
+ sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath());
+
+ partitionedDF.write().insertInto(sourceTableName);
+ partitionedDF.write().insertInto(sourceTableName);
+ }
+
+ private void createCompositePartitionedTable(String format) {
+ String createParquet = "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING %s " +
+ "PARTITIONED BY (id, dept) LOCATION '%s'";
+ sql(createParquet, sourceTableName, format, fileTableDir.getAbsolutePath());
+
+ compositePartitionedDF.write().insertInto(sourceTableName);
+ compositePartitionedDF.write().insertInto(sourceTableName);
+ }
+
+ private void createWeirdCaseTable() {
+ String createParquet =
+ "CREATE TABLE %s (id Integer, subdept String, dept String) " +
+ "PARTITIONED BY (`naMe` String) STORED AS parquet";
+
+ sql(createParquet, sourceTableName);
+
+ weirdColumnNamesDF.write().insertInto(sourceTableName);
+ weirdColumnNamesDF.write().insertInto(sourceTableName);
+
+ }
+
+ private void createUnpartitionedHiveTable() {
+ String createHive = "CREATE TABLE %s (id Integer, name String, dept String, subdept String) STORED AS parquet";
+
+ sql(createHive, sourceTableName);
+
+ unpartitionedDF.write().insertInto(sourceTableName);
+ unpartitionedDF.write().insertInto(sourceTableName);
+ }
+
+ private void createPartitionedHiveTable() {
+ String createHive = "CREATE TABLE %s (name String, dept String, subdept String) " +
+ "PARTITIONED BY (id Integer) STORED AS parquet";
+
+ sql(createHive, sourceTableName);
+
+ partitionedDF.write().insertInto(sourceTableName);
+ partitionedDF.write().insertInto(sourceTableName);
+ }
+}
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTablePartitionFields.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTablePartitionFields.java
new file mode 100644
index 0000000..9d63050
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTablePartitionFields.java
@@ -0,0 +1,415 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.util.Map;
+import org.apache.iceberg.PartitionSpec;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.spark.source.SparkTable;
+import org.apache.spark.sql.connector.catalog.CatalogManager;
+import org.apache.spark.sql.connector.catalog.Identifier;
+import org.apache.spark.sql.connector.catalog.TableCatalog;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestAlterTablePartitionFields extends SparkExtensionsTestBase {
+ public TestAlterTablePartitionFields(String catalogName, String implementation, Map<String, String> config) {
+ super(catalogName, implementation, config);
+ }
+
+ @After
+ public void removeTable() {
+ sql("DROP TABLE IF EXISTS %s", tableName);
+ }
+
+ @Test
+ public void testAddIdentityPartition() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned());
+
+ sql("ALTER TABLE %s ADD PARTITION FIELD category", tableName);
+
+ table.refresh();
+
+ PartitionSpec expected = PartitionSpec.builderFor(table.schema())
+ .withSpecId(1)
+ .identity("category")
+ .build();
+
+ Assert.assertEquals("Should have new spec field", expected, table.spec());
+ }
+
+ @Test
+ public void testAddBucketPartition() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned());
+
+ sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, id)", tableName);
+
+ table.refresh();
+
+ PartitionSpec expected = PartitionSpec.builderFor(table.schema())
+ .withSpecId(1)
+ .bucket("id", 16, "id_bucket_16")
+ .build();
+
+ Assert.assertEquals("Should have new spec field", expected, table.spec());
+ }
+
+ @Test
+ public void testAddTruncatePartition() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned());
+
+ sql("ALTER TABLE %s ADD PARTITION FIELD truncate(data, 4)", tableName);
+
+ table.refresh();
+
+ PartitionSpec expected = PartitionSpec.builderFor(table.schema())
+ .withSpecId(1)
+ .truncate("data", 4, "data_trunc_4")
+ .build();
+
+ Assert.assertEquals("Should have new spec field", expected, table.spec());
+ }
+
+ @Test
+ public void testAddYearsPartition() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned());
+
+ sql("ALTER TABLE %s ADD PARTITION FIELD years(ts)", tableName);
+
+ table.refresh();
+
+ PartitionSpec expected = PartitionSpec.builderFor(table.schema())
+ .withSpecId(1)
+ .year("ts")
+ .build();
+
+ Assert.assertEquals("Should have new spec field", expected, table.spec());
+ }
+
+ @Test
+ public void testAddMonthsPartition() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned());
+
+ sql("ALTER TABLE %s ADD PARTITION FIELD months(ts)", tableName);
+
+ table.refresh();
+
+ PartitionSpec expected = PartitionSpec.builderFor(table.schema())
+ .withSpecId(1)
+ .month("ts")
+ .build();
+
+ Assert.assertEquals("Should have new spec field", expected, table.spec());
+ }
+
+ @Test
+ public void testAddDaysPartition() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned());
+
+ sql("ALTER TABLE %s ADD PARTITION FIELD days(ts)", tableName);
+
+ table.refresh();
+
+ PartitionSpec expected = PartitionSpec.builderFor(table.schema())
+ .withSpecId(1)
+ .day("ts")
+ .build();
+
+ Assert.assertEquals("Should have new spec field", expected, table.spec());
+ }
+
+ @Test
+ public void testAddHoursPartition() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned());
+
+ sql("ALTER TABLE %s ADD PARTITION FIELD hours(ts)", tableName);
+
+ table.refresh();
+
+ PartitionSpec expected = PartitionSpec.builderFor(table.schema())
+ .withSpecId(1)
+ .hour("ts")
+ .build();
+
+ Assert.assertEquals("Should have new spec field", expected, table.spec());
+ }
+
+ @Test
+ public void testAddNamedPartition() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned());
+
+ sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, id) AS shard", tableName);
+
+ table.refresh();
+
+ PartitionSpec expected = PartitionSpec.builderFor(table.schema())
+ .withSpecId(1)
+ .bucket("id", 16, "shard")
+ .build();
+
+ Assert.assertEquals("Should have new spec field", expected, table.spec());
+ }
+
+ @Test
+ public void testDropIdentityPartition() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg PARTITIONED BY (category)",
+ tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ Assert.assertEquals("Table should start with 1 partition field", 1, table.spec().fields().size());
+
+ sql("ALTER TABLE %s DROP PARTITION FIELD category", tableName);
+
+ table.refresh();
+
+ PartitionSpec expected = PartitionSpec.builderFor(table.schema())
+ .withSpecId(1)
+ .alwaysNull("category", "category")
+ .build();
+
+ Assert.assertEquals("Should have new spec field", expected, table.spec());
+ }
+
+ @Test
+ public void testDropDaysPartition() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, ts timestamp, data string) USING iceberg PARTITIONED BY (days(ts))",
+ tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ Assert.assertEquals("Table should start with 1 partition field", 1, table.spec().fields().size());
+
+ sql("ALTER TABLE %s DROP PARTITION FIELD days(ts)", tableName);
+
+ table.refresh();
+
+ PartitionSpec expected = PartitionSpec.builderFor(table.schema())
+ .withSpecId(1)
+ .alwaysNull("ts", "ts_day")
+ .build();
+
+ Assert.assertEquals("Should have new spec field", expected, table.spec());
+ }
+
+ @Test
+ public void testDropBucketPartition() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg PARTITIONED BY (bucket(16, id))",
+ tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ Assert.assertEquals("Table should start with 1 partition field", 1, table.spec().fields().size());
+
+ sql("ALTER TABLE %s DROP PARTITION FIELD bucket(16, id)", tableName);
+
+ table.refresh();
+
+ PartitionSpec expected = PartitionSpec.builderFor(table.schema())
+ .withSpecId(1)
+ .alwaysNull("id", "id_bucket")
+ .build();
+
+ Assert.assertEquals("Should have new spec field", expected, table.spec());
+ }
+
+ @Test
+ public void testDropPartitionByName() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned());
+
+ sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, id) AS shard", tableName);
+
+ table.refresh();
+
+ Assert.assertEquals("Table should have 1 partition field", 1, table.spec().fields().size());
+
+ // Should be recognized as an Iceberg command even with extra whitespace
+ sql("ALTER TABLE %s DROP PARTITION \n FIELD shard", tableName);
+
+ table.refresh();
+
+ PartitionSpec expected = PartitionSpec.builderFor(table.schema())
+ .withSpecId(2)
+ .alwaysNull("id", "shard")
+ .build();
+
+ Assert.assertEquals("Should have new spec field", expected, table.spec());
+ }
+
+ @Test
+ public void testReplacePartition() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned());
+
+ sql("ALTER TABLE %s ADD PARTITION FIELD days(ts)", tableName);
+ table.refresh();
+ PartitionSpec expected = PartitionSpec.builderFor(table.schema())
+ .withSpecId(1)
+ .day("ts")
+ .build();
+ Assert.assertEquals("Should have new spec field", expected, table.spec());
+
+ sql("ALTER TABLE %s REPLACE PARTITION FIELD days(ts) WITH hours(ts)", tableName);
+ table.refresh();
+ expected = PartitionSpec.builderFor(table.schema())
+ .withSpecId(2)
+ .alwaysNull("ts", "ts_day")
+ .hour("ts")
+ .build();
+ Assert.assertEquals("Should changed from daily to hourly partitioned field", expected, table.spec());
+ }
+
+ @Test
+ public void testReplacePartitionAndRename() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned());
+
+ sql("ALTER TABLE %s ADD PARTITION FIELD days(ts)", tableName);
+ table.refresh();
+ PartitionSpec expected = PartitionSpec.builderFor(table.schema())
+ .withSpecId(1)
+ .day("ts")
+ .build();
+ Assert.assertEquals("Should have new spec field", expected, table.spec());
+
+ sql("ALTER TABLE %s REPLACE PARTITION FIELD days(ts) WITH hours(ts) AS hour_col", tableName);
+ table.refresh();
+ expected = PartitionSpec.builderFor(table.schema())
+ .withSpecId(2)
+ .alwaysNull("ts", "ts_day")
+ .hour("ts", "hour_col")
+ .build();
+ Assert.assertEquals("Should changed from daily to hourly partitioned field", expected, table.spec());
+ }
+
+ @Test
+ public void testReplaceNamedPartition() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned());
+
+ sql("ALTER TABLE %s ADD PARTITION FIELD days(ts) AS day_col", tableName);
+ table.refresh();
+ PartitionSpec expected = PartitionSpec.builderFor(table.schema())
+ .withSpecId(1)
+ .day("ts", "day_col")
+ .build();
+ Assert.assertEquals("Should have new spec field", expected, table.spec());
+
+ sql("ALTER TABLE %s REPLACE PARTITION FIELD day_col WITH hours(ts)", tableName);
+ table.refresh();
+ expected = PartitionSpec.builderFor(table.schema())
+ .withSpecId(2)
+ .alwaysNull("ts", "day_col")
+ .hour("ts")
+ .build();
+ Assert.assertEquals("Should changed from daily to hourly partitioned field", expected, table.spec());
+ }
+
+ @Test
+ public void testReplaceNamedPartitionAndRenameDifferently() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertTrue("Table should start unpartitioned", table.spec().isUnpartitioned());
+
+ sql("ALTER TABLE %s ADD PARTITION FIELD days(ts) AS day_col", tableName);
+ table.refresh();
+ PartitionSpec expected = PartitionSpec.builderFor(table.schema())
+ .withSpecId(1)
+ .day("ts", "day_col")
+ .build();
+ Assert.assertEquals("Should have new spec field", expected, table.spec());
+
+ sql("ALTER TABLE %s REPLACE PARTITION FIELD day_col WITH hours(ts) AS hour_col", tableName);
+ table.refresh();
+ expected = PartitionSpec.builderFor(table.schema())
+ .withSpecId(2)
+ .alwaysNull("ts", "day_col")
+ .hour("ts", "hour_col")
+ .build();
+ Assert.assertEquals("Should changed from daily to hourly partitioned field", expected, table.spec());
+ }
+
+ @Test
+ public void testSparkTableAddDropPartitions() throws Exception {
+ sql("CREATE TABLE %s (id bigint NOT NULL, ts timestamp, data string) USING iceberg", tableName);
+ Assert.assertEquals("spark table partition should be empty", 0, sparkTable().partitioning().length);
+
+ sql("ALTER TABLE %s ADD PARTITION FIELD bucket(16, id) AS shard", tableName);
+ assertPartitioningEquals(sparkTable(), 1, "bucket(16, id)");
+
+ sql("ALTER TABLE %s ADD PARTITION FIELD truncate(data, 4)", tableName);
+ assertPartitioningEquals(sparkTable(), 2, "truncate(data, 4)");
+
+ sql("ALTER TABLE %s ADD PARTITION FIELD years(ts)", tableName);
+ assertPartitioningEquals(sparkTable(), 3, "years(ts)");
+
+ sql("ALTER TABLE %s DROP PARTITION FIELD years(ts)", tableName);
+ assertPartitioningEquals(sparkTable(), 2, "truncate(data, 4)");
+
+ sql("ALTER TABLE %s DROP PARTITION FIELD truncate(data, 4)", tableName);
+ assertPartitioningEquals(sparkTable(), 1, "bucket(16, id)");
+
+ sql("ALTER TABLE %s DROP PARTITION FIELD shard", tableName);
+ sql("DESCRIBE %s", tableName);
+ Assert.assertEquals("spark table partition should be empty", 0, sparkTable().partitioning().length);
+ }
+
+ private void assertPartitioningEquals(SparkTable table, int len, String transform) {
+ Assert.assertEquals("spark table partition should be " + len, len, table.partitioning().length);
+ Assert.assertEquals("latest spark table partition transform should match",
+ transform, table.partitioning()[len - 1].toString());
+ }
+
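+ // Loads the table through the Spark catalog so its partitioning can be inspected via the DSv2 API.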
+ private SparkTable sparkTable() throws Exception {
+ validationCatalog.loadTable(tableIdent).refresh();
+ CatalogManager catalogManager = spark.sessionState().catalogManager();
+ TableCatalog catalog = (TableCatalog) catalogManager.catalog(catalogName);
+ Identifier identifier = Identifier.of(tableIdent.namespace().levels(), tableIdent.name());
+ return (SparkTable) catalog.loadTable(identifier);
+ }
+}
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTableSchema.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTableSchema.java
new file mode 100644
index 0000000..ac12953
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAlterTableSchema.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.util.Map;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.relocated.com.google.common.collect.Sets;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestAlterTableSchema extends SparkExtensionsTestBase {
+ public TestAlterTableSchema(String catalogName, String implementation, Map<String, String> config) {
+ super(catalogName, implementation, config);
+ }
+
+ @After
+ public void removeTable() {
+ sql("DROP TABLE IF EXISTS %s", tableName);
+ }
+
+ @Test
+ public void testSetIdentifierFields() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, " +
+ "location struct<lon:bigint NOT NULL,lat:bigint NOT NULL> NOT NULL) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertTrue("Table should start without identifier", table.schema().identifierFieldIds().isEmpty());
+
+ sql("ALTER TABLE %s SET IDENTIFIER FIELDS id", tableName);
+ table.refresh();
+ Assert.assertEquals("Should have new identifier field",
+ Sets.newHashSet(table.schema().findField("id").fieldId()),
+ table.schema().identifierFieldIds());
+
+ sql("ALTER TABLE %s SET IDENTIFIER FIELDS id, location.lon", tableName);
+ table.refresh();
+ Assert.assertEquals("Should have new identifier field",
+ Sets.newHashSet(
+ table.schema().findField("id").fieldId(),
+ table.schema().findField("location.lon").fieldId()),
+ table.schema().identifierFieldIds());
+
+ sql("ALTER TABLE %s SET IDENTIFIER FIELDS location.lon", tableName);
+ table.refresh();
+ Assert.assertEquals("Should have new identifier field",
+ Sets.newHashSet(table.schema().findField("location.lon").fieldId()),
+ table.schema().identifierFieldIds());
+ }
+
+ @Test
+ public void testSetInvalidIdentifierFields() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, id2 bigint) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertTrue("Table should start without identifier", table.schema().identifierFieldIds().isEmpty());
+ AssertHelpers.assertThrows("should not allow setting unknown fields",
+ IllegalArgumentException.class,
+ "not found in current schema or added columns",
+ () -> sql("ALTER TABLE %s SET IDENTIFIER FIELDS unknown", tableName));
+
+ AssertHelpers.assertThrows("should not allow setting optional fields",
+ IllegalArgumentException.class,
+ "not a required field",
+ () -> sql("ALTER TABLE %s SET IDENTIFIER FIELDS id2", tableName));
+ }
+
+ @Test
+ public void testDropIdentifierFields() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, " +
+ "location struct<lon:bigint NOT NULL,lat:bigint NOT NULL> NOT NULL) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertTrue("Table should start without identifier", table.schema().identifierFieldIds().isEmpty());
+
+ sql("ALTER TABLE %s SET IDENTIFIER FIELDS id, location.lon", tableName);
+ table.refresh();
+ Assert.assertEquals("Should have new identifier fields",
+ Sets.newHashSet(
+ table.schema().findField("id").fieldId(),
+ table.schema().findField("location.lon").fieldId()),
+ table.schema().identifierFieldIds());
+
+ sql("ALTER TABLE %s DROP IDENTIFIER FIELDS id", tableName);
+ table.refresh();
+ Assert.assertEquals("Should removed identifier field",
+ Sets.newHashSet(table.schema().findField("location.lon").fieldId()),
+ table.schema().identifierFieldIds());
+
+ sql("ALTER TABLE %s SET IDENTIFIER FIELDS id, location.lon", tableName);
+ table.refresh();
+ Assert.assertEquals("Should have new identifier fields",
+ Sets.newHashSet(
+ table.schema().findField("id").fieldId(),
+ table.schema().findField("location.lon").fieldId()),
+ table.schema().identifierFieldIds());
+
+ sql("ALTER TABLE %s DROP IDENTIFIER FIELDS id, location.lon", tableName);
+ table.refresh();
+ Assert.assertEquals("Should have no identifier field",
+ Sets.newHashSet(),
+ table.schema().identifierFieldIds());
+ }
+
+ @Test
+ public void testDropInvalidIdentifierFields() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string NOT NULL, " +
+ "location struct<lon:bigint NOT NULL,lat:bigint NOT NULL> NOT NULL) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertTrue("Table should start without identifier", table.schema().identifierFieldIds().isEmpty());
+ AssertHelpers.assertThrows("should not allow dropping unknown fields",
+ IllegalArgumentException.class,
+ "field unknown not found",
+ () -> sql("ALTER TABLE %s DROP IDENTIFIER FIELDS unknown", tableName));
+
+ sql("ALTER TABLE %s SET IDENTIFIER FIELDS id", tableName);
+ AssertHelpers.assertThrows("should not allow dropping a field that is not an identifier",
+ IllegalArgumentException.class,
+ "data is not an identifier field",
+ () -> sql("ALTER TABLE %s DROP IDENTIFIER FIELDS data", tableName));
+
+ AssertHelpers.assertThrows("should not allow dropping a nested field that is not an identifier",
+ IllegalArgumentException.class,
+ "location.lon is not an identifier field",
+ () -> sql("ALTER TABLE %s DROP IDENTIFIER FIELDS location.lon", tableName));
+ }
+}
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCallStatementParser.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCallStatementParser.java
new file mode 100644
index 0000000..99e9d06
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCallStatementParser.java
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.math.BigDecimal;
+import java.sql.Timestamp;
+import java.time.Instant;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.Literal;
+import org.apache.spark.sql.catalyst.expressions.Literal$;
+import org.apache.spark.sql.catalyst.parser.ParseException;
+import org.apache.spark.sql.catalyst.parser.ParserInterface;
+import org.apache.spark.sql.catalyst.parser.extensions.IcebergParseException;
+import org.apache.spark.sql.catalyst.plans.logical.CallArgument;
+import org.apache.spark.sql.catalyst.plans.logical.CallStatement;
+import org.apache.spark.sql.catalyst.plans.logical.NamedArgument;
+import org.apache.spark.sql.catalyst.plans.logical.PositionalArgument;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import scala.collection.JavaConverters;
+
+public class TestCallStatementParser {
+
+ @Rule
+ public TemporaryFolder temp = new TemporaryFolder();
+
+ private static SparkSession spark = null;
+ private static ParserInterface parser = null;
+
+ @BeforeClass
+ public static void startSpark() {
+ TestCallStatementParser.spark = SparkSession.builder()
+ .master("local[2]")
+ .config("spark.sql.extensions", IcebergSparkSessionExtensions.class.getName())
+ .config("spark.extra.prop", "value")
+ .getOrCreate();
+ TestCallStatementParser.parser = spark.sessionState().sqlParser();
+ }
+
+ @AfterClass
+ public static void stopSpark() {
+ SparkSession currentSpark = TestCallStatementParser.spark;
+ TestCallStatementParser.spark = null;
+ TestCallStatementParser.parser = null;
+ currentSpark.stop();
+ }
+
+ @Test
+ public void testCallWithPositionalArgs() throws ParseException {
+ CallStatement call = (CallStatement) parser.parsePlan("CALL c.n.func(1, '2', 3L, true, 1.0D, 9.0e1, 900e-1BD)");
+ Assert.assertEquals(ImmutableList.of("c", "n", "func"), JavaConverters.seqAsJavaList(call.name()));
+
+ Assert.assertEquals(7, call.args().size());
+
+ checkArg(call, 0, 1, DataTypes.IntegerType);
+ checkArg(call, 1, "2", DataTypes.StringType);
+ checkArg(call, 2, 3L, DataTypes.LongType);
+ checkArg(call, 3, true, DataTypes.BooleanType);
+ checkArg(call, 4, 1.0D, DataTypes.DoubleType);
+ checkArg(call, 5, 9.0e1, DataTypes.DoubleType);
+ checkArg(call, 6, new BigDecimal("900e-1"), DataTypes.createDecimalType(3, 1));
+ }
+
+ @Test
+ public void testCallWithNamedArgs() throws ParseException {
+ CallStatement call = (CallStatement) parser.parsePlan("CALL cat.system.func(c1 => 1, c2 => '2', c3 => true)");
+ Assert.assertEquals(ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name()));
+
+ Assert.assertEquals(3, call.args().size());
+
+ checkArg(call, 0, "c1", 1, DataTypes.IntegerType);
+ checkArg(call, 1, "c2", "2", DataTypes.StringType);
+ checkArg(call, 2, "c3", true, DataTypes.BooleanType);
+ }
+
+ @Test
+ public void testCallWithMixedArgs() throws ParseException {
+ CallStatement call = (CallStatement) parser.parsePlan("CALL cat.system.func(c1 => 1, '2')");
+ Assert.assertEquals(ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name()));
+
+ Assert.assertEquals(2, call.args().size());
+
+ checkArg(call, 0, "c1", 1, DataTypes.IntegerType);
+ checkArg(call, 1, "2", DataTypes.StringType);
+ }
+
+ @Test
+ public void testCallWithTimestampArg() throws ParseException {
+ CallStatement call = (CallStatement) parser.parsePlan("CALL cat.system.func(TIMESTAMP '2017-02-03T10:37:30.00Z')");
+ Assert.assertEquals(ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name()));
+
+ Assert.assertEquals(1, call.args().size());
+
+ checkArg(call, 0, Timestamp.from(Instant.parse("2017-02-03T10:37:30.00Z")), DataTypes.TimestampType);
+ }
+
+ @Test
+ public void testCallWithVarSubstitution() throws ParseException {
+ CallStatement call = (CallStatement) parser.parsePlan("CALL cat.system.func('${spark.extra.prop}')");
+ Assert.assertEquals(ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name()));
+
+ Assert.assertEquals(1, call.args().size());
+
+ checkArg(call, 0, "value", DataTypes.StringType);
+ }
+
+ @Test
+ public void testCallParseError() {
+ AssertHelpers.assertThrows("Should fail with a sensible parse error", IcebergParseException.class,
+ "missing '(' at 'radish'",
+ () -> parser.parsePlan("CALL cat.system radish kebab"));
+ }
+
+ private void checkArg(CallStatement call, int index, Object expectedValue, DataType expectedType) {
+ checkArg(call, index, null, expectedValue, expectedType);
+ }
+
+ private void checkArg(CallStatement call, int index, String expectedName,
+ Object expectedValue, DataType expectedType) {
+
+ if (expectedName != null) {
+ NamedArgument arg = checkCast(call.args().apply(index), NamedArgument.class);
+ Assert.assertEquals(expectedName, arg.name());
+ } else {
+ CallArgument arg = call.args().apply(index);
+ checkCast(arg, PositionalArgument.class);
+ }
+
+ Expression expectedExpr = toSparkLiteral(expectedValue, expectedType);
+ Expression actualExpr = call.args().apply(index).expr();
+ Assert.assertEquals("Arg types must match", expectedExpr.dataType(), actualExpr.dataType());
+ Assert.assertEquals("Arg must match", expectedExpr, actualExpr);
+ }
+
+ private Literal toSparkLiteral(Object value, DataType dataType) {
+ return Literal$.MODULE$.create(value, dataType);
+ }
+
+ private <T> T checkCast(Object value, Class<T> expectedClass) {
+ Assert.assertTrue("Expected instance of " + expectedClass.getName(), expectedClass.isInstance(value));
+ return expectedClass.cast(value);
+ }
+}
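The parser tests above show CALL statements resolving into CallStatement plans with positional and named arguments. A small sketch of driving the extended parser directly follows; the procedure name is illustrative, and the point is the argument handling rather than the procedure itself:

import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.parser.ParserInterface;
import org.apache.spark.sql.catalyst.plans.logical.CallStatement;
import scala.collection.JavaConverters;

public class CallParserSketch {
  public static void main(String[] args) throws Exception {
    SparkSession spark = SparkSession.builder()
        .master("local[2]")
        .config("spark.sql.extensions",
            "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
        .getOrCreate();

    // The session parser is the Iceberg extension parser wrapping Spark's own parser.
    ParserInterface parser = spark.sessionState().sqlParser();

    // Named arguments become NamedArgument nodes; positional ones become PositionalArgument nodes.
    CallStatement call = (CallStatement) parser.parsePlan(
        "CALL cat.system.rollback_to_snapshot(table => 'db.tbl', snapshot_id => 1L)");

    System.out.println(JavaConverters.seqAsJavaList(call.name())); // [cat, system, rollback_to_snapshot]
    System.out.println(call.args().size());                        // 2

    spark.stop();
  }
}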
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCherrypickSnapshotProcedure.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCherrypickSnapshotProcedure.java
new file mode 100644
index 0000000..ba1c5db
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCherrypickSnapshotProcedure.java
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.Snapshot;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.exceptions.ValidationException;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
+import org.apache.spark.sql.AnalysisException;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
+import org.junit.After;
+import org.junit.Test;
+
+import static org.apache.iceberg.TableProperties.WRITE_AUDIT_PUBLISH_ENABLED;
+
+public class TestCherrypickSnapshotProcedure extends SparkExtensionsTestBase {
+
+ public TestCherrypickSnapshotProcedure(String catalogName, String implementation, Map<String, String> config) {
+ super(catalogName, implementation, config);
+ }
+
+ @After
+ public void removeTables() {
+ sql("DROP TABLE IF EXISTS %s", tableName);
+ }
+
+ @Test
+ public void testCherrypickSnapshotUsingPositionalArgs() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED);
+
+ spark.conf().set("spark.wap.id", "1");
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should not see rows from staged snapshot",
+ ImmutableList.of(),
+ sql("SELECT * FROM %s", tableName));
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Snapshot wapSnapshot = Iterables.getOnlyElement(table.snapshots());
+
+ List<Object[]> output = sql(
+ "CALL %s.system.cherrypick_snapshot('%s', %dL)",
+ catalogName, tableIdent, wapSnapshot.snapshotId());
+
+ table.refresh();
+
+ Snapshot currentSnapshot = table.currentSnapshot();
+
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(wapSnapshot.snapshotId(), currentSnapshot.snapshotId())),
+ output);
+
+ assertEquals("Cherrypick must be successful",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM %s", tableName));
+ }
+
+ @Test
+ public void testCherrypickSnapshotUsingNamedArgs() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED);
+
+ spark.conf().set("spark.wap.id", "1");
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should not see rows from staged snapshot",
+ ImmutableList.of(),
+ sql("SELECT * FROM %s", tableName));
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Snapshot wapSnapshot = Iterables.getOnlyElement(table.snapshots());
+
+ List<Object[]> output = sql(
+ "CALL %s.system.cherrypick_snapshot(snapshot_id => %dL, table => '%s')",
+ catalogName, wapSnapshot.snapshotId(), tableIdent);
+
+ table.refresh();
+
+ Snapshot currentSnapshot = table.currentSnapshot();
+
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(wapSnapshot.snapshotId(), currentSnapshot.snapshotId())),
+ output);
+
+ assertEquals("Cherrypick must be successful",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM %s", tableName));
+ }
+
+ @Test
+ public void testCherrypickSnapshotRefreshesRelationCache() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED);
+
+ Dataset<Row> query = spark.sql("SELECT * FROM " + tableName + " WHERE id = 1");
+ query.createOrReplaceTempView("tmp");
+
+ spark.sql("CACHE TABLE tmp");
+
+ assertEquals("View should not produce rows", ImmutableList.of(), sql("SELECT * FROM tmp"));
+
+ spark.conf().set("spark.wap.id", "1");
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should not see rows from staged snapshot",
+ ImmutableList.of(),
+ sql("SELECT * FROM %s", tableName));
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Snapshot wapSnapshot = Iterables.getOnlyElement(table.snapshots());
+
+ sql("CALL %s.system.cherrypick_snapshot('%s', %dL)",
+ catalogName, tableIdent, wapSnapshot.snapshotId());
+
+ assertEquals("Cherrypick snapshot should be visible",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM tmp"));
+
+ sql("UNCACHE TABLE tmp");
+ }
+
+ @Test
+ public void testCherrypickInvalidSnapshot() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+
+ AssertHelpers.assertThrows("Should reject invalid snapshot id",
+ ValidationException.class, "Cannot cherry pick unknown snapshot id",
+ () -> sql("CALL %s.system.cherrypick_snapshot('%s', -1L)", catalogName, tableIdent));
+ }
+
+ @Test
+ public void testInvalidCherrypickSnapshotCases() {
+ AssertHelpers.assertThrows("Should not allow mixed args",
+ AnalysisException.class, "Named and positional arguments cannot be mixed",
+ () -> sql("CALL %s.system.cherrypick_snapshot('n', table => 't', 1L)", catalogName));
+
+ AssertHelpers.assertThrows("Should not resolve procedures in arbitrary namespaces",
+ NoSuchProcedureException.class, "not found",
+ () -> sql("CALL %s.custom.cherrypick_snapshot('n', 't', 1L)", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls without all required args",
+ AnalysisException.class, "Missing required parameters",
+ () -> sql("CALL %s.system.cherrypick_snapshot('t')", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls with empty table identifier",
+ IllegalArgumentException.class, "Cannot handle an empty identifier",
+ () -> sql("CALL %s.system.cherrypick_snapshot('', 1L)", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls with invalid arg types",
+ AnalysisException.class, "Wrong arg type for snapshot_id: cannot cast",
+ () -> sql("CALL %s.system.cherrypick_snapshot('t', 2.2)", catalogName));
+ }
+}
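The cherrypick tests above rely on the write-audit-publish (WAP) flow: an INSERT made under spark.wap.id stages a snapshot, and the cherrypick_snapshot procedure publishes it. A hedged sketch of that flow, with an illustrative catalog and table and a placeholder snapshot id that would normally be read from the table's snapshot metadata:

import org.apache.spark.sql.SparkSession;

public class CherrypickSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[2]")
        .config("spark.sql.extensions",
            "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
        .config("spark.sql.catalog.demo", "org.apache.iceberg.spark.SparkCatalog")
        .config("spark.sql.catalog.demo.type", "hadoop")
        .config("spark.sql.catalog.demo.warehouse", "file:/tmp/iceberg-warehouse")
        .getOrCreate();

    spark.sql("CREATE TABLE demo.db.tbl (id bigint NOT NULL, data string) USING iceberg");
    spark.sql("ALTER TABLE demo.db.tbl SET TBLPROPERTIES ('write.wap.enabled' = 'true')");

    // Writes made under a WAP id are staged and stay invisible to readers.
    spark.conf().set("spark.wap.id", "1");
    spark.sql("INSERT INTO demo.db.tbl VALUES (1, 'a')");

    // Publish the staged snapshot; 1234L is a placeholder for the staged snapshot id,
    // which would normally be looked up from the table's snapshot metadata.
    spark.sql("CALL demo.system.cherrypick_snapshot('db.tbl', 1234L)");

    spark.stop();
  }
}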
diff --git a/spark/build.gradle b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteDelete.java
similarity index 53%
copy from spark/build.gradle
copy to spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteDelete.java
index 30ea7fe..25f2fb9 100644
--- a/spark/build.gradle
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteDelete.java
@@ -17,13 +17,23 @@
* under the License.
*/
-// add enabled Spark version modules to the build
-def sparkVersions = (System.getProperty("sparkVersions") != null ? System.getProperty("sparkVersions") : System.getProperty("defaultSparkVersions")).split(",")
+package org.apache.iceberg.spark.extensions;
-if (jdkVersion == '8' && sparkVersions.contains("2.4")) {
- apply from: file("$projectDir/v2.4/build.gradle")
-}
+import java.util.Map;
+import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.junit.Ignore;
+
+@Ignore
+public class TestCopyOnWriteDelete extends TestDelete {
+
+ public TestCopyOnWriteDelete(String catalogName, String implementation, Map<String, String> config,
+ String fileFormat, Boolean vectorized, String distributionMode) {
+ super(catalogName, implementation, config, fileFormat, vectorized, distributionMode);
+ }
-if (sparkVersions.contains("3.0")) {
- apply from: file("$projectDir/v3.0/build.gradle")
+ @Override
+ protected Map<String, String> extraTableProperties() {
+ return ImmutableMap.of(TableProperties.DELETE_MODE, "copy-on-write");
+ }
}
diff --git a/spark/build.gradle b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java
similarity index 53%
copy from spark/build.gradle
copy to spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java
index 30ea7fe..f3c9ec6 100644
--- a/spark/build.gradle
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java
@@ -17,13 +17,23 @@
* under the License.
*/
-// add enabled Spark version modules to the build
-def sparkVersions = (System.getProperty("sparkVersions") != null ? System.getProperty("sparkVersions") : System.getProperty("defaultSparkVersions")).split(",")
+package org.apache.iceberg.spark.extensions;
-if (jdkVersion == '8' && sparkVersions.contains("2.4")) {
- apply from: file("$projectDir/v2.4/build.gradle")
-}
+import java.util.Map;
+import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.junit.Ignore;
+
+@Ignore
+public class TestCopyOnWriteMerge extends TestMerge {
+
+ public TestCopyOnWriteMerge(String catalogName, String implementation, Map<String, String> config,
+ String fileFormat, boolean vectorized, String distributionMode) {
+ super(catalogName, implementation, config, fileFormat, vectorized, distributionMode);
+ }
-if (sparkVersions.contains("3.0")) {
- apply from: file("$projectDir/v3.0/build.gradle")
+ @Override
+ protected Map<String, String> extraTableProperties() {
+ return ImmutableMap.of(TableProperties.MERGE_MODE, "copy-on-write");
+ }
}
diff --git a/spark/build.gradle b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteUpdate.java
similarity index 53%
copy from spark/build.gradle
copy to spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteUpdate.java
index 30ea7fe..f5e3bc0 100644
--- a/spark/build.gradle
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteUpdate.java
@@ -17,13 +17,23 @@
* under the License.
*/
-// add enabled Spark version modules to the build
-def sparkVersions = (System.getProperty("sparkVersions") != null ? System.getProperty("sparkVersions") : System.getProperty("defaultSparkVersions")).split(",")
+package org.apache.iceberg.spark.extensions;
-if (jdkVersion == '8' && sparkVersions.contains("2.4")) {
- apply from: file("$projectDir/v2.4/build.gradle")
-}
+import java.util.Map;
+import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.junit.Ignore;
+
+@Ignore
+public class TestCopyOnWriteUpdate extends TestUpdate {
+
+ public TestCopyOnWriteUpdate(String catalogName, String implementation, Map<String, String> config,
+ String fileFormat, boolean vectorized, String distributionMode) {
+ super(catalogName, implementation, config, fileFormat, vectorized, distributionMode);
+ }
-if (sparkVersions.contains("3.0")) {
- apply from: file("$projectDir/v3.0/build.gradle")
+ @Override
+ protected Map<String, String> extraTableProperties() {
+ return ImmutableMap.of(TableProperties.UPDATE_MODE, "copy-on-write");
+ }
}
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java
new file mode 100644
index 0000000..7b9dc6a
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java
@@ -0,0 +1,744 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.Snapshot;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.exceptions.ValidationException;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
+import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors;
+import org.apache.spark.SparkException;
+import org.apache.spark.sql.AnalysisException;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.hamcrest.CoreMatchers;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Assume;
+import org.junit.BeforeClass;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import static org.apache.iceberg.TableProperties.DELETE_ISOLATION_LEVEL;
+import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES;
+import static org.apache.iceberg.TableProperties.SPLIT_SIZE;
+import static org.apache.spark.sql.functions.lit;
+
+public abstract class TestDelete extends SparkRowLevelOperationsTestBase {
+
+ public TestDelete(String catalogName, String implementation, Map<String, String> config,
+ String fileFormat, Boolean vectorized, String distributionMode) {
+ super(catalogName, implementation, config, fileFormat, vectorized, distributionMode);
+ }
+
+ @BeforeClass
+ public static void setupSparkConf() {
+ spark.conf().set("spark.sql.shuffle.partitions", "4");
+ }
+
+ @After
+ public void removeTables() {
+ sql("DROP TABLE IF EXISTS %s", tableName);
+ sql("DROP TABLE IF EXISTS deleted_id");
+ sql("DROP TABLE IF EXISTS deleted_dep");
+ }
+
+ @Test
+ public void testDeleteFromEmptyTable() {
+ createAndInitUnpartitionedTable();
+
+ sql("DELETE FROM %s WHERE id IN (1)", tableName);
+ sql("DELETE FROM %s WHERE dep = 'hr'", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots()));
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testExplain() {
+ createAndInitUnpartitionedTable();
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName);
+
+ sql("EXPLAIN DELETE FROM %s WHERE id <=> 1", tableName);
+
+ sql("EXPLAIN DELETE FROM %s WHERE true", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertEquals("Should have 1 snapshot", 1, Iterables.size(table.snapshots()));
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+ }
+
+ @Test
+ public void testDeleteWithAlias() {
+ createAndInitUnpartitionedTable();
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName);
+
+ sql("DELETE FROM %s AS t WHERE t.id IS NULL", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "hr"), row(2, "hardware")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testDeleteWithDynamicFileFiltering() throws NoSuchTableException {
+ createAndInitPartitionedTable();
+
+ append(new Employee(1, "hr"), new Employee(3, "hr"));
+ append(new Employee(1, "hardware"), new Employee(2, "hardware"));
+
+ sql("DELETE FROM %s WHERE id = 2", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots()));
+
+ Snapshot currentSnapshot = table.currentSnapshot();
+ validateSnapshot(currentSnapshot, "overwrite", "1", "1", "1");
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "hardware"), row(1, "hr"), row(3, "hr")),
+ sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+ }
+
+ @Test
+ public void testDeleteNonExistingRecords() {
+ createAndInitPartitionedTable();
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName);
+
+ sql("DELETE FROM %s AS t WHERE t.id > 10", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots()));
+
+ Snapshot currentSnapshot = table.currentSnapshot();
+ validateSnapshot(currentSnapshot, "overwrite", "0", null, null);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+ }
+
+ @Test
+ public void testDeleteWithoutCondition() {
+ createAndInitPartitionedTable();
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName);
+ sql("INSERT INTO TABLE %s VALUES (2, 'hardware')", tableName);
+ sql("INSERT INTO TABLE %s VALUES (null, 'hr')", tableName);
+
+ sql("DELETE FROM %s", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertEquals("Should have 4 snapshots", 4, Iterables.size(table.snapshots()));
+
+ // should be a delete instead of an overwrite as it is done through a metadata operation
+ Snapshot currentSnapshot = table.currentSnapshot();
+ validateSnapshot(currentSnapshot, "delete", "2", "3", null);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(),
+ sql("SELECT * FROM %s", tableName));
+ }
+
+ @Test
+ public void testDeleteUsingMetadataWithComplexCondition() {
+ createAndInitPartitionedTable();
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'dep1')", tableName);
+ sql("INSERT INTO TABLE %s VALUES (2, 'dep2')", tableName);
+ sql("INSERT INTO TABLE %s VALUES (null, 'dep3')", tableName);
+
+ sql("DELETE FROM %s WHERE dep > 'dep2' OR dep = CAST(4 AS STRING) OR dep = 'dep2'", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertEquals("Should have 4 snapshots", 4, Iterables.size(table.snapshots()));
+
+ // should be a delete instead of an overwrite as it is done through a metadata operation
+ Snapshot currentSnapshot = table.currentSnapshot();
+ validateSnapshot(currentSnapshot, "delete", "2", "2", null);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "dep1")),
+ sql("SELECT * FROM %s", tableName));
+ }
+
+ @Test
+ public void testDeleteWithArbitraryPartitionPredicates() {
+ createAndInitPartitionedTable();
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName);
+ sql("INSERT INTO TABLE %s VALUES (2, 'hardware')", tableName);
+ sql("INSERT INTO TABLE %s VALUES (null, 'hr')", tableName);
+
+ // %% is an escaped version of %
+ sql("DELETE FROM %s WHERE id = 10 OR dep LIKE '%%ware'", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertEquals("Should have 4 snapshots", 4, Iterables.size(table.snapshots()));
+
+ // should be an overwrite since the delete cannot be executed using a metadata operation
+ Snapshot currentSnapshot = table.currentSnapshot();
+ validateSnapshot(currentSnapshot, "overwrite", "1", "1", null);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "hr"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+ }
+
+ @Test
+ public void testDeleteWithNonDeterministicCondition() {
+ createAndInitPartitionedTable();
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware')", tableName);
+
+ AssertHelpers.assertThrows("Should complain about non-deterministic expressions",
+ AnalysisException.class, "nondeterministic expressions are only allowed",
+ () -> sql("DELETE FROM %s WHERE id = 1 AND rand() > 0.5", tableName));
+ }
+
+ @Test
+ public void testDeleteWithFoldableConditions() {
+ createAndInitPartitionedTable();
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware')", tableName);
+
+ // should keep all rows and not trigger execution
+ sql("DELETE FROM %s WHERE false", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "hr"), row(2, "hardware")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ // should keep all rows and not trigger execution
+ sql("DELETE FROM %s WHERE 50 <> 50", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "hr"), row(2, "hardware")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ // should keep all rows and not trigger execution
+ sql("DELETE FROM %s WHERE 1 > null", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "hr"), row(2, "hardware")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ // should remove all rows
+ sql("DELETE FROM %s WHERE 21 = 21", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots()));
+ }
+
+ @Test
+ public void testDeleteWithNullConditions() {
+ createAndInitPartitionedTable();
+
+ sql("INSERT INTO TABLE %s VALUES (0, null), (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName);
+
+ // should keep all rows as null is never equal to null
+ sql("DELETE FROM %s WHERE dep = null", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(0, null), row(1, "hr"), row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ // null = 'software' -> null
+ // should delete using metadata operation only
+ sql("DELETE FROM %s WHERE dep = 'software'", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(0, null), row(1, "hr"), row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ // should delete using metadata operation only
+ sql("DELETE FROM %s WHERE dep <=> NULL", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots()));
+
+ Snapshot currentSnapshot = table.currentSnapshot();
+ validateSnapshot(currentSnapshot, "delete", "1", "1", null);
+ }
+
+ @Ignore // TODO: fails due to SPARK-33267
+ public void testDeleteWithInAndNotInConditions() {
+ createAndInitUnpartitionedTable();
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName);
+
+ sql("DELETE FROM %s WHERE id IN (1, null)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ sql("DELETE FROM %s WHERE id NOT IN (null, 1)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ sql("DELETE FROM %s WHERE id NOT IN (1, 10)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+ }
+
+ @Test
+ public void testDeleteWithMultipleRowGroupsParquet() throws NoSuchTableException {
+ Assume.assumeTrue(fileFormat.equalsIgnoreCase("parquet"));
+
+ createAndInitPartitionedTable();
+
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", tableName, PARQUET_ROW_GROUP_SIZE_BYTES, 100);
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", tableName, SPLIT_SIZE, 100);
+
+ List<Integer> ids = new ArrayList<>();
+ for (int id = 1; id <= 200; id++) {
+ ids.add(id);
+ }
+ Dataset<Row> df = spark.createDataset(ids, Encoders.INT())
+ .withColumnRenamed("value", "id")
+ .withColumn("dep", lit("hr"));
+ df.coalesce(1).writeTo(tableName).append();
+
+ Assert.assertEquals(200, spark.table(tableName).count());
+
+ // delete a record from one of two row groups and copy over the second one
+ sql("DELETE FROM %s WHERE id IN (200, 201)", tableName);
+
+ Assert.assertEquals(199, spark.table(tableName).count());
+ }
+
+ @Test
+ public void testDeleteWithConditionOnNestedColumn() {
+ createAndInitNestedColumnsTable();
+
+ sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", 3, \"c2\", \"v1\"))", tableName);
+ sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2, \"c2\", \"v2\"))", tableName);
+
+ sql("DELETE FROM %s WHERE complex.c1 = id + 2", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(2)),
+ sql("SELECT id FROM %s", tableName));
+
+ sql("DELETE FROM %s t WHERE t.complex.c1 = id", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(),
+ sql("SELECT id FROM %s", tableName));
+ }
+
+ @Test
+ public void testDeleteWithInSubquery() throws NoSuchTableException {
+ createAndInitUnpartitionedTable();
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName);
+
+ createOrReplaceView("deleted_id", Arrays.asList(0, 1, null), Encoders.INT());
+ createOrReplaceView("deleted_dep", Arrays.asList("software", "hr"), Encoders.STRING());
+
+ sql("DELETE FROM %s WHERE id IN (SELECT * FROM deleted_id) AND dep IN (SELECT * from deleted_dep)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ append(new Employee(1, "hr"), new Employee(-1, "hr"));
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "hr"), row(1, "hr"), row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ sql("DELETE FROM %s WHERE id IS NULL OR id IN (SELECT value + 2 FROM deleted_id)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "hr"), row(1, "hr")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ append(new Employee(null, "hr"), new Employee(2, "hr"));
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "hr"), row(1, "hr"), row(2, "hr"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ sql("DELETE FROM %s WHERE id IN (SELECT value + 2 FROM deleted_id) AND dep = 'hr'", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "hr"), row(1, "hr"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+ }
+
+ @Test
+ public void testDeleteWithMultiColumnInSubquery() throws NoSuchTableException {
+ createAndInitUnpartitionedTable();
+
+ append(new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr"));
+
+ List<Employee> deletedEmployees = Arrays.asList(new Employee(null, "hr"), new Employee(1, "hr"));
+ createOrReplaceView("deleted_employee", deletedEmployees, Encoders.bean(Employee.class));
+
+ sql("DELETE FROM %s WHERE (id, dep) IN (SELECT id, dep FROM deleted_employee)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+ }
+
+ @Ignore // TODO: not supported because the SPARK-25154 fix is not yet available
+ public void testDeleteWithNotInSubquery() throws NoSuchTableException {
+ createAndInitUnpartitionedTable();
+
+ append(new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr"));
+
+ createOrReplaceView("deleted_id", Arrays.asList(-1, -2, null), Encoders.INT());
+ createOrReplaceView("deleted_dep", Arrays.asList("software", "hr"), Encoders.STRING());
+
+ // the file filter subquery (nested loop left-anti join) returns 0 records
+ sql("DELETE FROM %s WHERE id NOT IN (SELECT * FROM deleted_id)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ sql("DELETE FROM %s WHERE id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ sql("DELETE FROM %s WHERE id NOT IN (SELECT * FROM deleted_id) OR dep IN ('software', 'hr')", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(2, "hardware")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ sql("DELETE FROM %s t WHERE " +
+ "id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL) AND " +
+ "EXISTS (SELECT 1 FROM FROM deleted_dep WHERE t.dep = deleted_dep.value)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(2, "hardware")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ sql("DELETE FROM %s t WHERE " +
+ "id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL) OR " +
+ "EXISTS (SELECT 1 FROM FROM deleted_dep WHERE t.dep = deleted_dep.value)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+ }
+
+ @Test
+ public void testDeleteWithNotInSubqueryNotSupported() throws NoSuchTableException {
+ createAndInitUnpartitionedTable();
+
+ append(new Employee(1, "hr"), new Employee(2, "hardware"));
+
+ createOrReplaceView("deleted_id", Arrays.asList(-1, -2, null), Encoders.INT());
+
+ AssertHelpers.assertThrows("Should complain about NOT IN subquery",
+ AnalysisException.class, "Null-aware predicate subqueries are not currently supported",
+ () -> sql("DELETE FROM %s WHERE id NOT IN (SELECT * FROM deleted_id)", tableName));
+ }
+
+ @Test
+ public void testDeleteOnNonIcebergTableNotSupported() throws NoSuchTableException {
+ createOrReplaceView("testtable", "{ \"c1\": -100, \"c2\": -200 }");
+
+ AssertHelpers.assertThrows("Delete is not supported for non iceberg table",
+ AnalysisException.class, "DELETE is only supported with v2 tables.",
+ () -> sql("DELETE FROM %s WHERE c1 = -100", "testtable"));
+ }
+
+ @Test
+ public void testDeleteWithExistSubquery() throws NoSuchTableException {
+ createAndInitUnpartitionedTable();
+
+ append(new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr"));
+
+ createOrReplaceView("deleted_id", Arrays.asList(-1, -2, null), Encoders.INT());
+ createOrReplaceView("deleted_dep", Arrays.asList("software", "hr"), Encoders.STRING());
+
+ sql("DELETE FROM %s t WHERE EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ sql("DELETE FROM %s t WHERE EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ sql("DELETE FROM %s t WHERE EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value) OR t.id IS NULL", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(2, "hardware")),
+ sql("SELECT * FROM %s", tableName));
+
+ sql("DELETE FROM %s t WHERE " +
+ "EXISTS (SELECT 1 FROM deleted_id di WHERE t.id = di.value) AND " +
+ "EXISTS (SELECT 1 FROM deleted_dep dd WHERE t.dep = dd.value)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(2, "hardware")),
+ sql("SELECT * FROM %s", tableName));
+ }
+
+ @Test
+ public void testDeleteWithNotExistsSubquery() throws NoSuchTableException {
+ createAndInitUnpartitionedTable();
+
+ append(new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr"));
+
+ createOrReplaceView("deleted_id", Arrays.asList(-1, -2, null), Encoders.INT());
+ createOrReplaceView("deleted_dep", Arrays.asList("software", "hr"), Encoders.STRING());
+
+ sql("DELETE FROM %s t WHERE " +
+ "NOT EXISTS (SELECT 1 FROM deleted_id di WHERE t.id = di.value + 2) AND " +
+ "NOT EXISTS (SELECT 1 FROM deleted_dep dd WHERE t.dep = dd.value)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "hr"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ sql("DELETE FROM %s t WHERE NOT EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ String subquery = "SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2";
+ sql("DELETE FROM %s t WHERE NOT EXISTS (%s) OR t.id = 1", tableName, subquery);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+ }
+
+ @Test
+ public void testDeleteWithScalarSubquery() throws NoSuchTableException {
+ createAndInitUnpartitionedTable();
+
+ append(new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr"));
+
+ createOrReplaceView("deleted_id", Arrays.asList(1, 100, null), Encoders.INT());
+
+ sql("DELETE FROM %s t WHERE id <= (SELECT min(value) FROM deleted_id)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+ }
+
+ @Test
+ public void testDeleteThatRequiresGroupingBeforeWrite() throws NoSuchTableException {
+ createAndInitPartitionedTable();
+
+ append(new Employee(0, "hr"), new Employee(1, "hr"), new Employee(2, "hr"));
+ append(new Employee(0, "ops"), new Employee(1, "ops"), new Employee(2, "ops"));
+ append(new Employee(0, "hr"), new Employee(1, "hr"), new Employee(2, "hr"));
+ append(new Employee(0, "ops"), new Employee(1, "ops"), new Employee(2, "ops"));
+
+ createOrReplaceView("deleted_id", Arrays.asList(1, 100), Encoders.INT());
+
+ String originalNumOfShufflePartitions = spark.conf().get("spark.sql.shuffle.partitions");
+ try {
+ // set the num of shuffle partitions to 1 to ensure we have only 1 writing task
+ spark.conf().set("spark.sql.shuffle.partitions", "1");
+
+ sql("DELETE FROM %s t WHERE id IN (SELECT * FROM deleted_id)", tableName);
+ Assert.assertEquals("Should have expected num of rows", 8L, spark.table(tableName).count());
+ } finally {
+ spark.conf().set("spark.sql.shuffle.partitions", originalNumOfShufflePartitions);
+ }
+ }
+
+ @Test
+ public synchronized void testDeleteWithSerializableIsolation() throws InterruptedException {
+ // cannot run tests with concurrency for Hadoop tables without atomic renames
+ Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop"));
+
+ createAndInitUnpartitionedTable();
+
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, DELETE_ISOLATION_LEVEL, "serializable");
+
+ ExecutorService executorService = MoreExecutors.getExitingExecutorService(
+ (ThreadPoolExecutor) Executors.newFixedThreadPool(2));
+
+ AtomicInteger barrier = new AtomicInteger(0);
+
+ // delete thread
+ Future<?> deleteFuture = executorService.submit(() -> {
+ for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) {
+ while (barrier.get() < numOperations * 2) {
+ sleep(10);
+ }
+ sql("DELETE FROM %s WHERE id = 1", tableName);
+ barrier.incrementAndGet();
+ }
+ });
+
+ // append thread
+ Future<?> appendFuture = executorService.submit(() -> {
+ for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) {
+ while (barrier.get() < numOperations * 2) {
+ sleep(10);
+ }
+ sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName);
+ barrier.incrementAndGet();
+ }
+ });
+
+ try {
+ deleteFuture.get();
+ Assert.fail("Expected a validation exception");
+ } catch (ExecutionException e) {
+ Throwable sparkException = e.getCause();
+ Assert.assertThat(sparkException, CoreMatchers.instanceOf(SparkException.class));
+ Throwable validationException = sparkException.getCause();
+ Assert.assertThat(validationException, CoreMatchers.instanceOf(ValidationException.class));
+ String errMsg = validationException.getMessage();
+ Assert.assertThat(errMsg, CoreMatchers.containsString("Found conflicting files that can contain"));
+ } finally {
+ appendFuture.cancel(true);
+ }
+
+ executorService.shutdown();
+ Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES));
+ }
+
+ @Test
+ public synchronized void testDeleteWithSnapshotIsolation() throws InterruptedException, ExecutionException {
+ // cannot run tests with concurrency for Hadoop tables without atomic renames
+ Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop"));
+
+ createAndInitUnpartitionedTable();
+
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, DELETE_ISOLATION_LEVEL, "snapshot");
+
+ ExecutorService executorService = MoreExecutors.getExitingExecutorService(
+ (ThreadPoolExecutor) Executors.newFixedThreadPool(2));
+
+ AtomicInteger barrier = new AtomicInteger(0);
+
+ // delete thread
+ Future<?> deleteFuture = executorService.submit(() -> {
+ for (int numOperations = 0; numOperations < 20; numOperations++) {
+ while (barrier.get() < numOperations * 2) {
+ sleep(10);
+ }
+ sql("DELETE FROM %s WHERE id = 1", tableName);
+ barrier.incrementAndGet();
+ }
+ });
+
+ // append thread
+ Future<?> appendFuture = executorService.submit(() -> {
+ for (int numOperations = 0; numOperations < 20; numOperations++) {
+ while (barrier.get() < numOperations * 2) {
+ sleep(10);
+ }
+ sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName);
+ barrier.incrementAndGet();
+ }
+ });
+
+ try {
+ deleteFuture.get();
+ } finally {
+ appendFuture.cancel(true);
+ }
+
+ executorService.shutdown();
+ Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES));
+ }
+
+ @Test
+ public void testDeleteRefreshesRelationCache() throws NoSuchTableException {
+ createAndInitPartitionedTable();
+
+ append(new Employee(1, "hr"), new Employee(3, "hr"));
+ append(new Employee(1, "hardware"), new Employee(2, "hardware"));
+
+ Dataset<Row> query = spark.sql("SELECT * FROM " + tableName + " WHERE id = 1");
+ query.createOrReplaceTempView("tmp");
+
+ spark.sql("CACHE TABLE tmp");
+
+ assertEquals("View should have correct data",
+ ImmutableList.of(row(1, "hardware"), row(1, "hr")),
+ sql("SELECT * FROM tmp ORDER BY id, dep"));
+
+ sql("DELETE FROM %s WHERE id = 1", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots()));
+
+ Snapshot currentSnapshot = table.currentSnapshot();
+ validateSnapshot(currentSnapshot, "overwrite", "2", "2", "2");
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(2, "hardware"), row(3, "hr")),
+ sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+
+ assertEquals("Should refresh the relation cache",
+ ImmutableList.of(),
+ sql("SELECT * FROM tmp ORDER BY id, dep"));
+
+ spark.sql("UNCACHE TABLE tmp");
+ }
+
+ // TODO: multiple stripes for ORC
+
+ protected void createAndInitPartitionedTable() {
+ sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg PARTITIONED BY (dep)", tableName);
+ initTable();
+ }
+
+ protected void createAndInitUnpartitionedTable() {
+ sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg", tableName);
+ initTable();
+ }
+
+ protected void createAndInitNestedColumnsTable() {
+ sql("CREATE TABLE %s (id INT, complex STRUCT<c1:INT,c2:STRING>) USING iceberg", tableName);
+ initTable();
+ }
+
+ protected void append(Employee... employees) throws NoSuchTableException {
+ List<Employee> input = Arrays.asList(employees);
+ Dataset<Row> inputDF = spark.createDataFrame(input, Employee.class);
+ inputDF.coalesce(1).writeTo(tableName).append();
+ }
+}
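TestDelete above covers metadata-only deletes, rewrites with dynamic file filtering, subqueries, and isolation levels, while the TestCopyOnWrite* subclasses pin the copy-on-write write mode through a table property. A short sketch of enabling that mode explicitly; the catalog and table names are illustrative assumptions:

import org.apache.spark.sql.SparkSession;

public class CopyOnWriteDeleteSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[2]")
        .config("spark.sql.extensions",
            "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
        .config("spark.sql.catalog.demo", "org.apache.iceberg.spark.SparkCatalog")
        .config("spark.sql.catalog.demo.type", "hadoop")
        .config("spark.sql.catalog.demo.warehouse", "file:/tmp/iceberg-warehouse")
        .getOrCreate();

    spark.sql("CREATE TABLE demo.db.emp (id INT, dep STRING) USING iceberg PARTITIONED BY (dep)");

    // TableProperties.DELETE_MODE ("write.delete.mode") set to copy-on-write, matching
    // TestCopyOnWriteDelete#extraTableProperties above.
    spark.sql("ALTER TABLE demo.db.emp SET TBLPROPERTIES ('write.delete.mode' = 'copy-on-write')");

    spark.sql("INSERT INTO demo.db.emp VALUES (1, 'hr'), (2, 'hardware')");

    // With copy-on-write, data files containing matches are rewritten without the deleted rows.
    spark.sql("DELETE FROM demo.db.emp WHERE id = 1");

    spark.stop();
  }
}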
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestExpireSnapshotsProcedure.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestExpireSnapshotsProcedure.java
new file mode 100644
index 0000000..b09d12f
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestExpireSnapshotsProcedure.java
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.io.IOException;
+import java.sql.Timestamp;
+import java.time.Instant;
+import java.util.List;
+import java.util.Map;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.Snapshot;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.exceptions.ValidationException;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
+import org.apache.iceberg.spark.SparkCatalog;
+import org.apache.spark.sql.AnalysisException;
+import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Test;
+
+import static org.apache.iceberg.TableProperties.GC_ENABLED;
+
+public class TestExpireSnapshotsProcedure extends SparkExtensionsTestBase {
+
+ public TestExpireSnapshotsProcedure(String catalogName, String implementation, Map<String, String> config) {
+ super(catalogName, implementation, config);
+ }
+
+ @After
+ public void removeTables() {
+ sql("DROP TABLE IF EXISTS %s", tableName);
+ }
+
+ @Test
+ public void testExpireSnapshotsInEmptyTable() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+
+ List<Object[]> output = sql(
+ "CALL %s.system.expire_snapshots('%s')",
+ catalogName, tableIdent);
+ assertEquals("Should not delete any files", ImmutableList.of(row(0L, 0L, 0L)), output);
+ }
+
+ @Test
+ public void testExpireSnapshotsUsingPositionalArgs() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Snapshot firstSnapshot = table.currentSnapshot();
+
+ waitUntilAfter(firstSnapshot.timestampMillis());
+
+ sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName);
+
+ table.refresh();
+
+ Snapshot secondSnapshot = table.currentSnapshot();
+ Timestamp secondSnapshotTimestamp = Timestamp.from(Instant.ofEpochMilli(secondSnapshot.timestampMillis()));
+
+ Assert.assertEquals("Should be 2 snapshots", 2, Iterables.size(table.snapshots()));
+
+ // expire without retainLast param
+ List<Object[]> output1 = sql(
+ "CALL %s.system.expire_snapshots('%s', TIMESTAMP '%s')",
+ catalogName, tableIdent, secondSnapshotTimestamp);
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(0L, 0L, 1L)),
+ output1);
+
+ table.refresh();
+
+ Assert.assertEquals("Should expire one snapshot", 1, Iterables.size(table.snapshots()));
+
+ sql("INSERT OVERWRITE %s VALUES (3, 'c')", tableName);
+ sql("INSERT INTO TABLE %s VALUES (4, 'd')", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(3L, "c"), row(4L, "d")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ table.refresh();
+
+ waitUntilAfter(table.currentSnapshot().timestampMillis());
+
+ Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis()));
+
+ Assert.assertEquals("Should be 3 snapshots", 3, Iterables.size(table.snapshots()));
+
+ // expire with retainLast param
+ List<Object[]> output = sql(
+ "CALL %s.system.expire_snapshots('%s', TIMESTAMP '%s', 2)",
+ catalogName, tableIdent, currentTimestamp);
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(2L, 2L, 1L)),
+ output);
+ }
+
+ @Test
+ public void testExpireSnapshotUsingNamedArgs() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+ sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ Assert.assertEquals("Should be 2 snapshots", 2, Iterables.size(table.snapshots()));
+
+ waitUntilAfter(table.currentSnapshot().timestampMillis());
+
+ Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis()));
+
+ List<Object[]> output = sql(
+ "CALL %s.system.expire_snapshots(" +
+ "older_than => TIMESTAMP '%s'," +
+ "table => '%s'," +
+ "retain_last => 1)",
+ catalogName, currentTimestamp, tableIdent);
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(0L, 0L, 1L)),
+ output);
+ }
+
+ @Test
+ public void testExpireSnapshotsGCDisabled() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+
+ sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'false')", tableName, GC_ENABLED);
+
+ AssertHelpers.assertThrows("Should reject call",
+ ValidationException.class, "Cannot expire snapshots: GC is disabled",
+ () -> sql("CALL %s.system.expire_snapshots('%s')", catalogName, tableIdent));
+ }
+
+ @Test
+ public void testInvalidExpireSnapshotsCases() {
+ AssertHelpers.assertThrows("Should not allow mixed args",
+ AnalysisException.class, "Named and positional arguments cannot be mixed",
+ () -> sql("CALL %s.system.expire_snapshots('n', table => 't')", catalogName));
+
+ AssertHelpers.assertThrows("Should not resolve procedures in arbitrary namespaces",
+ NoSuchProcedureException.class, "not found",
+ () -> sql("CALL %s.custom.expire_snapshots('n', 't')", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls without all required args",
+ AnalysisException.class, "Missing required parameters",
+ () -> sql("CALL %s.system.expire_snapshots()", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls with invalid arg types",
+ AnalysisException.class, "Wrong arg type",
+ () -> sql("CALL %s.system.expire_snapshots('n', 2.2)", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls with empty table identifier",
+ IllegalArgumentException.class, "Cannot handle an empty identifier",
+ () -> sql("CALL %s.system.expire_snapshots('')", catalogName));
+ }
+
+ @Test
+ public void testResolvingTableInAnotherCatalog() throws IOException {
+ String anotherCatalog = "another_" + catalogName;
+ spark.conf().set("spark.sql.catalog." + anotherCatalog, SparkCatalog.class.getName());
+ spark.conf().set("spark.sql.catalog." + anotherCatalog + ".type", "hadoop");
+ spark.conf().set("spark.sql.catalog." + anotherCatalog + ".warehouse", "file:" + temp.newFolder().toString());
+
+ sql("CREATE TABLE %s.%s (id bigint NOT NULL, data string) USING iceberg", anotherCatalog, tableIdent);
+
+ AssertHelpers.assertThrows("Should reject calls for a table in another catalog",
+ IllegalArgumentException.class, "Cannot run procedure in catalog",
+ () -> sql("CALL %s.system.expire_snapshots('%s')", catalogName, anotherCatalog + "." + tableName));
+ }
+
+ @Test
+ public void testConcurrentExpireSnapshots() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+ sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName);
+ sql("INSERT INTO TABLE %s VALUES (3, 'c')", tableName);
+ sql("INSERT INTO TABLE %s VALUES (4, 'd')", tableName);
+
+ Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis()));
+ List<Object[]> output = sql(
+ "CALL %s.system.expire_snapshots(" +
+ "older_than => TIMESTAMP '%s'," +
+ "table => '%s'," +
+ "max_concurrent_deletes => %s," +
+ "retain_last => 1)",
+ catalogName, currentTimestamp, tableIdent, 4);
+ assertEquals("Expiring snapshots concurrently should succeed", ImmutableList.of(row(0L, 0L, 3L)), output);
+ }
+
+ @Test
+ public void testConcurrentExpireSnapshotsWithInvalidInput() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+
+ AssertHelpers.assertThrows("Should throw an error when max_concurrent_deletes = 0",
+ IllegalArgumentException.class, "max_concurrent_deletes should have value > 0",
+ () -> sql("CALL %s.system.expire_snapshots(table => '%s', max_concurrent_deletes => %s)",
+ catalogName, tableIdent, 0));
+
+ AssertHelpers.assertThrows("Should throw an error when max_concurrent_deletes < 0 ",
+ IllegalArgumentException.class, "max_concurrent_deletes should have value > 0",
+ () -> sql(
+ "CALL %s.system.expire_snapshots(table => '%s', max_concurrent_deletes => %s)",
+ catalogName, tableIdent, -1));
+ }
+}
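For reference, the expire_snapshots calls exercised above can also be issued outside the test harness through a plain SparkSession. The snippet below is only an illustrative sketch and is not part of this commit; the catalog name "my_catalog", the table "db.sample", and the timestamp literal are assumptions chosen for the example.

    import org.apache.spark.sql.SparkSession;

    public class ExpireSnapshotsExample {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .appName("expire-snapshots-example")
            .getOrCreate();

        // Named arguments may be given in any order; retain_last keeps at least one
        // ancestor snapshot even if it is older than the cutoff, and
        // max_concurrent_deletes sizes the delete thread pool (must be > 0).
        spark.sql(
            "CALL my_catalog.system.expire_snapshots(" +
            "  table => 'db.sample'," +
            "  older_than => TIMESTAMP '2021-10-01 00:00:00'," +
            "  retain_last => 1," +
            "  max_concurrent_deletes => 4)").show();
      }
    }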
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestIcebergExpressions.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestIcebergExpressions.java
new file mode 100644
index 0000000..ce88814
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestIcebergExpressions.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.math.BigDecimal;
+import java.util.Map;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.expressions.IcebergTruncateTransform;
+import org.junit.After;
+import org.junit.Test;
+
+public class TestIcebergExpressions extends SparkExtensionsTestBase {
+
+ public TestIcebergExpressions(String catalogName, String implementation, Map<String, String> config) {
+ super(catalogName, implementation, config);
+ }
+
+ @After
+ public void removeTables() {
+ sql("DROP TABLE IF EXISTS %s", tableName);
+ sql("DROP VIEW IF EXISTS emp");
+ sql("DROP VIEW IF EXISTS v");
+ }
+
+ @Test
+ public void testTruncateExpressions() {
+ sql("CREATE TABLE %s ( " +
+ " int_c INT, long_c LONG, dec_c DECIMAL(4, 2), str_c STRING, binary_c BINARY " +
+ ") USING iceberg", tableName);
+
+ sql("CREATE TEMPORARY VIEW emp " +
+ "AS SELECT * FROM VALUES (101, 10001, 10.65, '101-Employee', CAST('1234' AS BINARY)) " +
+ "AS EMP(int_c, long_c, dec_c, str_c, binary_c)");
+
+ sql("INSERT INTO %s SELECT * FROM emp", tableName);
+
+ Dataset<Row> df = spark.sql("SELECT * FROM " + tableName);
+ df.select(
+ new Column(new IcebergTruncateTransform(df.col("int_c").expr(), 2)).as("int_c"),
+ new Column(new IcebergTruncateTransform(df.col("long_c").expr(), 2)).as("long_c"),
+ new Column(new IcebergTruncateTransform(df.col("dec_c").expr(), 50)).as("dec_c"),
+ new Column(new IcebergTruncateTransform(df.col("str_c").expr(), 2)).as("str_c"),
+ new Column(new IcebergTruncateTransform(df.col("binary_c").expr(), 2)).as("binary_c")
+ ).createOrReplaceTempView("v");
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(100, 10000L, new BigDecimal("10.50"), "10", "12")),
+ sql("SELECT int_c, long_c, dec_c, str_c, CAST(binary_c AS STRING) FROM v"));
+ }
+}
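The expected row in the test above follows Iceberg's truncate transform semantics: integers and longs are truncated down to a multiple of the width (101 -> 100, 10001 -> 10000), decimals are truncated so the unscaled value is a multiple of the width (10.65 with scale 2 and width 50 becomes 10.50), and strings and binary values are cut to a prefix of the given length ("101-Employee" -> "10", '1234' -> '12'). A small helper in the same style as the test is sketched below; the class name and the assumption that the DataFrame is loaded elsewhere are illustrative, not part of the commit.

    import org.apache.spark.sql.Column;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.catalyst.expressions.IcebergTruncateTransform;

    final class TruncateColumns {
      private TruncateColumns() {
      }

      // Wraps an existing column in Iceberg's truncate transform with the given width,
      // keeping the original column name, exactly as the test does inline.
      static Column truncate(Dataset<Row> df, String name, int width) {
        return new Column(new IcebergTruncateTransform(df.col(name).expr(), width)).as(name);
      }
    }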
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
new file mode 100644
index 0000000..f647d4a
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
@@ -0,0 +1,1492 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.DistributionMode;
+import org.apache.iceberg.exceptions.ValidationException;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors;
+import org.apache.spark.SparkException;
+import org.apache.spark.sql.AnalysisException;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.hamcrest.CoreMatchers;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Assume;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.apache.iceberg.TableProperties.MERGE_CARDINALITY_CHECK_ENABLED;
+import static org.apache.iceberg.TableProperties.MERGE_ISOLATION_LEVEL;
+import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES;
+import static org.apache.iceberg.TableProperties.SPLIT_SIZE;
+import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE;
+import static org.apache.spark.sql.functions.lit;
+
+public abstract class TestMerge extends SparkRowLevelOperationsTestBase {
+
+ public TestMerge(String catalogName, String implementation, Map<String, String> config,
+ String fileFormat, boolean vectorized, String distributionMode) {
+ super(catalogName, implementation, config, fileFormat, vectorized, distributionMode);
+ }
+
+ @BeforeClass
+ public static void setupSparkConf() {
+ spark.conf().set("spark.sql.shuffle.partitions", "4");
+ }
+
+ @After
+ public void removeTables() {
+ sql("DROP TABLE IF EXISTS %s", tableName);
+ sql("DROP TABLE IF EXISTS source");
+ }
+
+ // TODO: add tests for multiple NOT MATCHED clauses when we move to Spark 3.1
+
+ @Test
+ public void testMergeIntoEmptyTargetInsertAllNonMatchingRows() {
+ createAndInitTable("id INT, dep STRING");
+
+ createOrReplaceView("source", "id INT, dep STRING",
+ "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" +
+ "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" +
+ "{ \"id\": 3, \"dep\": \"emp-id-3\" }");
+
+ sql("MERGE INTO %s AS t USING source AS s " +
+ "ON t.id == s.id " +
+ "WHEN NOT MATCHED THEN " +
+ " INSERT *", tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(1, "emp-id-1"), // new
+ row(2, "emp-id-2"), // new
+ row(3, "emp-id-3") // new
+ );
+ assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testMergeIntoEmptyTargetInsertOnlyMatchingRows() {
+ createAndInitTable("id INT, dep STRING");
+
+ createOrReplaceView("source", "id INT, dep STRING",
+ "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" +
+ "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" +
+ "{ \"id\": 3, \"dep\": \"emp-id-3\" }");
+
+ sql("MERGE INTO %s AS t USING source AS s " +
+ "ON t.id == s.id " +
+ "WHEN NOT MATCHED AND (s.id >=2) THEN " +
+ " INSERT *", tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(2, "emp-id-2"), // new
+ row(3, "emp-id-3") // new
+ );
+ assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testMergeWithOnlyUpdateClause() {
+ createAndInitTable("id INT, dep STRING",
+ "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-six\" }");
+
+ createOrReplaceView("source", "id INT, dep STRING",
+ "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" +
+ "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ sql("MERGE INTO %s AS t USING source AS s " +
+ "ON t.id == s.id " +
+ "WHEN MATCHED AND t.id = 1 THEN " +
+ " UPDATE SET *", tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(1, "emp-id-1"), // updated
+ row(6, "emp-id-six") // kept
+ );
+ assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testMergeWithOnlyDeleteClause() {
+ createAndInitTable("id INT, dep STRING",
+ "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ createOrReplaceView("source", "id INT, dep STRING",
+ "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" +
+ "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ sql("MERGE INTO %s AS t USING source AS s " +
+ "ON t.id == s.id " +
+ "WHEN MATCHED AND t.id = 6 THEN " +
+ " DELETE", tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(1, "emp-id-one") // kept
+ );
+ assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testMergeWithAllClauses() {
+ createAndInitTable("id INT, dep STRING",
+ "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ createOrReplaceView("source", "id INT, dep STRING",
+ "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" +
+ "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ sql("MERGE INTO %s AS t USING source AS s " +
+ "ON t.id == s.id " +
+ "WHEN MATCHED AND t.id = 1 THEN " +
+ " UPDATE SET * " +
+ "WHEN MATCHED AND t.id = 6 THEN " +
+ " DELETE " +
+ "WHEN NOT MATCHED AND s.id = 2 THEN " +
+ " INSERT *", tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(1, "emp-id-1"), // updated
+ row(2, "emp-id-2") // new
+ );
+ assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testMergeWithAllClausesWithExplicitColumnSpecification() {
+ createAndInitTable("id INT, dep STRING",
+ "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ createOrReplaceView("source", "id INT, dep STRING",
+ "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" +
+ "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ sql("MERGE INTO %s AS t USING source AS s " +
+ "ON t.id == s.id " +
+ "WHEN MATCHED AND t.id = 1 THEN " +
+ " UPDATE SET t.id = s.id, t.dep = s.dep " +
+ "WHEN MATCHED AND t.id = 6 THEN " +
+ " DELETE " +
+ "WHEN NOT MATCHED AND s.id = 2 THEN " +
+ " INSERT (t.id, t.dep) VALUES (s.id, s.dep)", tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(1, "emp-id-1"), // updated
+ row(2, "emp-id-2") // new
+ );
+ assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testMergeWithSourceCTE() {
+ createAndInitTable("id INT, dep STRING",
+ "{ \"id\": 2, \"dep\": \"emp-id-two\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ createOrReplaceView("source", "id INT, dep STRING",
+ "{ \"id\": 2, \"dep\": \"emp-id-3\" }\n" +
+ "{ \"id\": 1, \"dep\": \"emp-id-2\" }\n" +
+ "{ \"id\": 5, \"dep\": \"emp-id-6\" }");
+
+ sql("WITH cte1 AS (SELECT id + 1 AS id, dep FROM source) " +
+ "MERGE INTO %s AS t USING cte1 AS s " +
+ "ON t.id == s.id " +
+ "WHEN MATCHED AND t.id = 2 THEN " +
+ " UPDATE SET * " +
+ "WHEN MATCHED AND t.id = 6 THEN " +
+ " DELETE " +
+ "WHEN NOT MATCHED AND s.id = 3 THEN " +
+ " INSERT *", tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(2, "emp-id-2"), // updated
+ row(3, "emp-id-3") // new
+ );
+ assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testMergeWithSourceFromSetOps() {
+ createAndInitTable("id INT, dep STRING",
+ "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ createOrReplaceView("source", "id INT, dep STRING",
+ "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" +
+ "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ String derivedSource =
+ "SELECT * FROM source WHERE id = 2 " +
+ "UNION ALL " +
+ "SELECT * FROM source WHERE id = 1 OR id = 6";
+
+ sql("MERGE INTO %s AS t USING (%s) AS s " +
+ "ON t.id == s.id " +
+ "WHEN MATCHED AND t.id = 1 THEN " +
+ " UPDATE SET * " +
+ "WHEN MATCHED AND t.id = 6 THEN " +
+ " DELETE " +
+ "WHEN NOT MATCHED AND s.id = 2 THEN " +
+ " INSERT *", tableName, derivedSource);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(1, "emp-id-1"), // updated
+ row(2, "emp-id-2") // new
+ );
+ assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testMergeWithMultipleUpdatesForTargetRow() {
+ createAndInitTable("id INT, dep STRING",
+ "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ createOrReplaceView("source", "id INT, dep STRING",
+ "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" +
+ "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" +
+ "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ String errorMsg = "a single row from the target table with multiple rows of the source table";
+ AssertHelpers.assertThrows("Should complain non iceberg target table",
+ SparkException.class, errorMsg,
+ () -> {
+ sql("MERGE INTO %s AS t USING source AS s " +
+ "ON t.id == s.id " +
+ "WHEN MATCHED AND t.id = 1 THEN " +
+ " UPDATE SET * " +
+ "WHEN MATCHED AND t.id = 6 THEN " +
+ " DELETE " +
+ "WHEN NOT MATCHED AND s.id = 2 THEN " +
+ " INSERT *", tableName);
+ });
+
+ assertEquals("Target should be unchanged",
+ ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+ }
+
+ @Test
+ public void testMergeWithDisabledCardinalityCheck() {
+ createAndInitTable("id INT, dep STRING",
+ "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ createOrReplaceView("source", "id INT, dep STRING",
+ "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" +
+ "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" +
+ "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ try {
+ // disable the cardinality check
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%b')", tableName, MERGE_CARDINALITY_CHECK_ENABLED, false);
+
+ sql("MERGE INTO %s AS t USING source AS s " +
+ "ON t.id == s.id " +
+ "WHEN MATCHED AND t.id = 1 THEN " +
+ " UPDATE SET * " +
+ "WHEN MATCHED AND t.id = 6 THEN " +
+ " DELETE " +
+ "WHEN NOT MATCHED AND s.id = 2 THEN " +
+ " INSERT *", tableName);
+ } finally {
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%b')", tableName, MERGE_CARDINALITY_CHECK_ENABLED, true);
+ }
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "emp-id-1"), row(1, "emp-id-1"), row(2, "emp-id-2")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+ }
+
+ @Test
+ public void testMergeWithUnconditionalDelete() {
+ createAndInitTable("id INT, dep STRING",
+ "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ createOrReplaceView("source", "id INT, dep STRING",
+ "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" +
+ "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" +
+ "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ sql("MERGE INTO %s AS t USING source AS s " +
+ "ON t.id == s.id " +
+ "WHEN MATCHED THEN " +
+ " DELETE " +
+ "WHEN NOT MATCHED AND s.id = 2 THEN " +
+ " INSERT *", tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(2, "emp-id-2") // new
+ );
+ assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testMergeWithSingleConditionalDelete() {
+ createAndInitTable("id INT, dep STRING",
+ "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ createOrReplaceView("source", "id INT, dep STRING",
+ "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" +
+ "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" +
+ "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ String errorMsg = "a single row from the target table with multiple rows of the source table";
+ AssertHelpers.assertThrows("Should complain non iceberg target table",
+ SparkException.class, errorMsg,
+ () -> {
+ sql("MERGE INTO %s AS t USING source AS s " +
+ "ON t.id == s.id " +
+ "WHEN MATCHED AND t.id = 1 THEN " +
+ " DELETE " +
+ "WHEN NOT MATCHED AND s.id = 2 THEN " +
+ " INSERT *", tableName);
+ });
+
+ assertEquals("Target should be unchanged",
+ ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+ }
+
+ @Test
+ public void testMergeWithIdentityTransform() {
+ for (DistributionMode mode : DistributionMode.values()) {
+ createAndInitTable("id INT, dep STRING");
+ sql("ALTER TABLE %s ADD PARTITION FIELD identity(dep)", tableName);
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, WRITE_DISTRIBUTION_MODE, mode.modeName());
+
+ append(tableName,
+ "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ createOrReplaceView("source", "id INT, dep STRING",
+ "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" +
+ "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ sql("MERGE INTO %s AS t USING source AS s " +
+ "ON t.id == s.id " +
+ "WHEN MATCHED AND t.id = 1 THEN " +
+ " UPDATE SET * " +
+ "WHEN MATCHED AND t.id = 6 THEN " +
+ " DELETE " +
+ "WHEN NOT MATCHED AND s.id = 2 THEN " +
+ " INSERT *", tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(1, "emp-id-1"), // updated
+ row(2, "emp-id-2") // new
+ );
+ assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ removeTables();
+ }
+ }
+
+ @Test
+ public void testMergeWithDaysTransform() {
+ for (DistributionMode mode : DistributionMode.values()) {
+ createAndInitTable("id INT, ts TIMESTAMP");
+ sql("ALTER TABLE %s ADD PARTITION FIELD days(ts)", tableName);
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, WRITE_DISTRIBUTION_MODE, mode.modeName());
+
+ append(tableName, "id INT, ts TIMESTAMP",
+ "{ \"id\": 1, \"ts\": \"2000-01-01 00:00:00\" }\n" +
+ "{ \"id\": 6, \"ts\": \"2000-01-06 00:00:00\" }");
+
+ createOrReplaceView("source", "id INT, ts TIMESTAMP",
+ "{ \"id\": 2, \"ts\": \"2001-01-02 00:00:00\" }\n" +
+ "{ \"id\": 1, \"ts\": \"2001-01-01 00:00:00\" }\n" +
+ "{ \"id\": 6, \"ts\": \"2001-01-06 00:00:00\" }");
+
+ sql("MERGE INTO %s AS t USING source AS s " +
+ "ON t.id == s.id " +
+ "WHEN MATCHED AND t.id = 1 THEN " +
+ " UPDATE SET * " +
+ "WHEN MATCHED AND t.id = 6 THEN " +
+ " DELETE " +
+ "WHEN NOT MATCHED AND s.id = 2 THEN " +
+ " INSERT *", tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(1, "2001-01-01 00:00:00"), // updated
+ row(2, "2001-01-02 00:00:00") // new
+ );
+ assertEquals("Should have expected rows",
+ expectedRows,
+ sql("SELECT id, CAST(ts AS STRING) FROM %s ORDER BY id", tableName));
+
+ removeTables();
+ }
+ }
+
+ @Test
+ public void testMergeWithBucketTransform() {
+ for (DistributionMode mode : DistributionMode.values()) {
+ createAndInitTable("id INT, dep STRING");
+ sql("ALTER TABLE %s ADD PARTITION FIELD bucket(2, dep)", tableName);
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, WRITE_DISTRIBUTION_MODE, mode.modeName());
+
+ append(tableName,
+ "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ createOrReplaceView("source", "id INT, dep STRING",
+ "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" +
+ "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ sql("MERGE INTO %s AS t USING source AS s " +
+ "ON t.id == s.id " +
+ "WHEN MATCHED AND t.id = 1 THEN " +
+ " UPDATE SET * " +
+ "WHEN MATCHED AND t.id = 6 THEN " +
+ " DELETE " +
+ "WHEN NOT MATCHED AND s.id = 2 THEN " +
+ " INSERT *", tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(1, "emp-id-1"), // updated
+ row(2, "emp-id-2") // new
+ );
+ assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ removeTables();
+ }
+ }
+
+ @Test
+ public void testMergeWithTruncateTransform() {
+ for (DistributionMode mode : DistributionMode.values()) {
+ createAndInitTable("id INT, dep STRING");
+ sql("ALTER TABLE %s ADD PARTITION FIELD truncate(dep, 2)", tableName);
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, WRITE_DISTRIBUTION_MODE, mode.modeName());
+
+ append(tableName,
+ "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ createOrReplaceView("source", "id INT, dep STRING",
+ "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" +
+ "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ sql("MERGE INTO %s AS t USING source AS s " +
+ "ON t.id == s.id " +
+ "WHEN MATCHED AND t.id = 1 THEN " +
+ " UPDATE SET * " +
+ "WHEN MATCHED AND t.id = 6 THEN " +
+ " DELETE " +
+ "WHEN NOT MATCHED AND s.id = 2 THEN " +
+ " INSERT *", tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(1, "emp-id-1"), // updated
+ row(2, "emp-id-2") // new
+ );
+ assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ removeTables();
+ }
+ }
+
+ @Test
+ public void testMergeIntoPartitionedAndOrderedTable() {
+ for (DistributionMode mode : DistributionMode.values()) {
+ createAndInitTable("id INT, dep STRING");
+ sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName);
+ sql("ALTER TABLE %s WRITE ORDERED BY (id)", tableName);
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, WRITE_DISTRIBUTION_MODE, mode.modeName());
+
+ append(tableName,
+ "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ createOrReplaceView("source", "id INT, dep STRING",
+ "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n" +
+ "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n" +
+ "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ sql("MERGE INTO %s AS t USING source AS s " +
+ "ON t.id == s.id " +
+ "WHEN MATCHED AND t.id = 1 THEN " +
+ " UPDATE SET * " +
+ "WHEN MATCHED AND t.id = 6 THEN " +
+ " DELETE " +
+ "WHEN NOT MATCHED AND s.id = 2 THEN " +
+ " INSERT *", tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(1, "emp-id-1"), // updated
+ row(2, "emp-id-2") // new
+ );
+ assertEquals("Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ removeTables();
+ }
+ }
+
+ @Test
+ public void testSelfMerge() {
+ createAndInitTable("id INT, v STRING",
+ "{ \"id\": 1, \"v\": \"v1\" }\n" +
+ "{ \"id\": 2, \"v\": \"v2\" }");
+
+ sql("MERGE INTO %s t USING %s s " +
+ "ON t.id == s.id " +
+ "WHEN MATCHED AND t.id = 1 THEN " +
+ " UPDATE SET v = 'x' " +
+ "WHEN NOT MATCHED THEN " +
+ " INSERT *", tableName, tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(1, "x"), // updated
+ row(2, "v2") // kept
+ );
+ assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testMergeWithSourceAsSelfSubquery() {
+ createAndInitTable("id INT, v STRING",
+ "{ \"id\": 1, \"v\": \"v1\" }\n" +
+ "{ \"id\": 2, \"v\": \"v2\" }");
+
+ createOrReplaceView("source", Arrays.asList(1, null), Encoders.INT());
+
+ sql("MERGE INTO %s t USING (SELECT id AS value FROM %s r JOIN source ON r.id = source.value) s " +
+ "ON t.id == s.value " +
+ "WHEN MATCHED AND t.id = 1 THEN " +
+ " UPDATE SET v = 'x' " +
+ "WHEN NOT MATCHED THEN " +
+ " INSERT (v, id) VALUES ('invalid', -1) ", tableName, tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(1, "x"), // updated
+ row(2, "v2") // kept
+ );
+ assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public synchronized void testMergeWithSerializableIsolation() throws InterruptedException {
+ // cannot run tests with concurrency for Hadoop tables without atomic renames
+ Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop"));
+
+ createAndInitTable("id INT, dep STRING");
+ createOrReplaceView("source", Collections.singletonList(1), Encoders.INT());
+
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, MERGE_ISOLATION_LEVEL, "serializable");
+
+ ExecutorService executorService = MoreExecutors.getExitingExecutorService(
+ (ThreadPoolExecutor) Executors.newFixedThreadPool(2));
+
+ AtomicInteger barrier = new AtomicInteger(0);
+
+ // merge thread
+ Future<?> mergeFuture = executorService.submit(() -> {
+ for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) {
+ while (barrier.get() < numOperations * 2) {
+ sleep(10);
+ }
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.value " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET dep = 'x'", tableName);
+ barrier.incrementAndGet();
+ }
+ });
+
+ // append thread
+ Future<?> appendFuture = executorService.submit(() -> {
+ for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) {
+ while (barrier.get() < numOperations * 2) {
+ sleep(10);
+ }
+ sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName);
+ barrier.incrementAndGet();
+ }
+ });
+
+ try {
+ mergeFuture.get();
+ Assert.fail("Expected a validation exception");
+ } catch (ExecutionException e) {
+ Throwable sparkException = e.getCause();
+ Assert.assertThat(sparkException, CoreMatchers.instanceOf(SparkException.class));
+ Throwable validationException = sparkException.getCause();
+ Assert.assertThat(validationException, CoreMatchers.instanceOf(ValidationException.class));
+ String errMsg = validationException.getMessage();
+ Assert.assertThat(errMsg, CoreMatchers.containsString("Found conflicting files that can contain"));
+ } finally {
+ appendFuture.cancel(true);
+ }
+
+ executorService.shutdown();
+ Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES));
+ }
+
+ @Test
+ public synchronized void testMergeWithSnapshotIsolation() throws InterruptedException, ExecutionException {
+ // cannot run tests with concurrency for Hadoop tables without atomic renames
+ Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop"));
+
+ createAndInitTable("id INT, dep STRING");
+ createOrReplaceView("source", Collections.singletonList(1), Encoders.INT());
+
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, MERGE_ISOLATION_LEVEL, "snapshot");
+
+ ExecutorService executorService = MoreExecutors.getExitingExecutorService(
+ (ThreadPoolExecutor) Executors.newFixedThreadPool(2));
+
+ AtomicInteger barrier = new AtomicInteger(0);
+
+ // merge thread
+ Future<?> mergeFuture = executorService.submit(() -> {
+ for (int numOperations = 0; numOperations < 20; numOperations++) {
+ while (barrier.get() < numOperations * 2) {
+ sleep(10);
+ }
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.value " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET dep = 'x'", tableName);
+ barrier.incrementAndGet();
+ }
+ });
+
+ // append thread
+ Future<?> appendFuture = executorService.submit(() -> {
+ for (int numOperations = 0; numOperations < 20; numOperations++) {
+ while (barrier.get() < numOperations * 2) {
+ sleep(10);
+ }
+ sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName);
+ barrier.incrementAndGet();
+ }
+ });
+
+ try {
+ mergeFuture.get();
+ } finally {
+ appendFuture.cancel(true);
+ }
+
+ executorService.shutdown();
+ Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES));
+ }
+
+ @Test
+ public void testMergeWithExtraColumnsInSource() {
+ createAndInitTable("id INT, v STRING",
+ "{ \"id\": 1, \"v\": \"v1\" }\n" +
+ "{ \"id\": 2, \"v\": \"v2\" }");
+ createOrReplaceView("source",
+ "{ \"id\": 1, \"extra_col\": -1, \"v\": \"v1_1\" }\n" +
+ "{ \"id\": 3, \"extra_col\": -1, \"v\": \"v3\" }\n" +
+ "{ \"id\": 4, \"extra_col\": -1, \"v\": \"v4\" }");
+
+ sql("MERGE INTO %s t USING source " +
+ "ON t.id == source.id " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET v = source.v " +
+ "WHEN NOT MATCHED THEN " +
+ " INSERT (v, id) VALUES (source.v, source.id)", tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(1, "v1_1"), // new
+ row(2, "v2"), // kept
+ row(3, "v3"), // new
+ row(4, "v4") // new
+ );
+ assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testMergeWithNullsInTargetAndSource() {
+ createAndInitTable("id INT, v STRING",
+ "{ \"id\": null, \"v\": \"v1\" }\n" +
+ "{ \"id\": 2, \"v\": \"v2\" }");
+
+ createOrReplaceView("source",
+ "{ \"id\": null, \"v\": \"v1_1\" }\n" +
+ "{ \"id\": 4, \"v\": \"v4\" }");
+
+ sql("MERGE INTO %s t USING source " +
+ "ON t.id == source.id " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET v = source.v " +
+ "WHEN NOT MATCHED THEN " +
+ " INSERT (v, id) VALUES (source.v, source.id)", tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(null, "v1"), // kept
+ row(null, "v1_1"), // new
+ row(2, "v2"), // kept
+ row(4, "v4") // new
+ );
+ assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", tableName));
+ }
+
+ @Test
+ public void testMergeWithNullSafeEquals() {
+ createAndInitTable("id INT, v STRING",
+ "{ \"id\": null, \"v\": \"v1\" }\n" +
+ "{ \"id\": 2, \"v\": \"v2\" }");
+
+ createOrReplaceView("source",
+ "{ \"id\": null, \"v\": \"v1_1\" }\n" +
+ "{ \"id\": 4, \"v\": \"v4\" }");
+
+ sql("MERGE INTO %s t USING source " +
+ "ON t.id <=> source.id " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET v = source.v " +
+ "WHEN NOT MATCHED THEN " +
+ " INSERT (v, id) VALUES (source.v, source.id)", tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(null, "v1_1"), // updated
+ row(2, "v2"), // kept
+ row(4, "v4") // new
+ );
+ assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", tableName));
+ }
+
+ @Test
+ public void testMergeWithNullCondition() {
+ createAndInitTable("id INT, v STRING",
+ "{ \"id\": null, \"v\": \"v1\" }\n" +
+ "{ \"id\": 2, \"v\": \"v2\" }");
+
+ createOrReplaceView("source",
+ "{ \"id\": null, \"v\": \"v1_1\" }\n" +
+ "{ \"id\": 2, \"v\": \"v2_2\" }");
+
+ sql("MERGE INTO %s t USING source " +
+ "ON t.id == source.id AND NULL " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET v = source.v " +
+ "WHEN NOT MATCHED THEN " +
+ " INSERT (v, id) VALUES (source.v, source.id)", tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(null, "v1"), // kept
+ row(null, "v1_1"), // new
+ row(2, "v2"), // kept
+ row(2, "v2_2") // new
+ );
+ assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", tableName));
+ }
+
+ @Test
+ public void testMergeWithNullActionConditions() {
+ createAndInitTable("id INT, v STRING",
+ "{ \"id\": 1, \"v\": \"v1\" }\n" +
+ "{ \"id\": 2, \"v\": \"v2\" }");
+
+ createOrReplaceView("source",
+ "{ \"id\": 1, \"v\": \"v1_1\" }\n" +
+ "{ \"id\": 2, \"v\": \"v2_2\" }\n" +
+ "{ \"id\": 3, \"v\": \"v3_3\" }");
+
+ // all conditions are NULL and will never match any rows
+ sql("MERGE INTO %s t USING source " +
+ "ON t.id == source.id " +
+ "WHEN MATCHED AND source.id = 1 AND NULL THEN " +
+ " UPDATE SET v = source.v " +
+ "WHEN MATCHED AND source.v = 'v1_1' AND NULL THEN " +
+ " DELETE " +
+ "WHEN NOT MATCHED AND source.id = 3 AND NULL THEN " +
+ " INSERT (v, id) VALUES (source.v, source.id)", tableName);
+
+ ImmutableList<Object[]> expectedRows1 = ImmutableList.of(
+ row(1, "v1"), // kept
+ row(2, "v2") // kept
+ );
+ assertEquals("Output should match", expectedRows1, sql("SELECT * FROM %s ORDER BY v", tableName));
+
+ // only the update and insert conditions are NULL
+ sql("MERGE INTO %s t USING source " +
+ "ON t.id == source.id " +
+ "WHEN MATCHED AND source.id = 1 AND NULL THEN " +
+ " UPDATE SET v = source.v " +
+ "WHEN MATCHED AND source.v = 'v1_1' THEN " +
+ " DELETE " +
+ "WHEN NOT MATCHED AND source.id = 3 AND NULL THEN " +
+ " INSERT (v, id) VALUES (source.v, source.id)", tableName);
+
+ ImmutableList<Object[]> expectedRows2 = ImmutableList.of(
+ row(2, "v2") // kept
+ );
+ assertEquals("Output should match", expectedRows2, sql("SELECT * FROM %s ORDER BY v", tableName));
+ }
+
+ @Test
+ public void testMergeWithMultipleMatchingActions() {
+ createAndInitTable("id INT, v STRING",
+ "{ \"id\": 1, \"v\": \"v1\" }\n" +
+ "{ \"id\": 2, \"v\": \"v2\" }");
+
+ createOrReplaceView("source",
+ "{ \"id\": 1, \"v\": \"v1_1\" }\n" +
+ "{ \"id\": 2, \"v\": \"v2_2\" }");
+
+ // the order of match actions is important in this case
+ sql("MERGE INTO %s t USING source " +
+ "ON t.id == source.id " +
+ "WHEN MATCHED AND source.id = 1 THEN " +
+ " UPDATE SET v = source.v " +
+ "WHEN MATCHED AND source.v = 'v1_1' THEN " +
+ " DELETE " +
+ "WHEN NOT MATCHED THEN " +
+ " INSERT (v, id) VALUES (source.v, source.id)", tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(1, "v1_1"), // updated (also matches the delete cond but update is first)
+ row(2, "v2") // kept (matches neither the update nor the delete cond)
+ );
+ assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", tableName));
+ }
+
+ @Test
+ public void testMergeWithMultipleRowGroupsParquet() throws NoSuchTableException {
+ Assume.assumeTrue(fileFormat.equalsIgnoreCase("parquet"));
+
+ createAndInitTable("id INT, dep STRING");
+ sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName);
+
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", tableName, PARQUET_ROW_GROUP_SIZE_BYTES, 100);
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", tableName, SPLIT_SIZE, 100);
+
+ createOrReplaceView("source", Collections.singletonList(1), Encoders.INT());
+
+ List<Integer> ids = new ArrayList<>();
+ for (int id = 1; id <= 200; id++) {
+ ids.add(id);
+ }
+ Dataset<Row> df = spark.createDataset(ids, Encoders.INT())
+ .withColumnRenamed("value", "id")
+ .withColumn("dep", lit("hr"));
+ df.coalesce(1).writeTo(tableName).append();
+
+ Assert.assertEquals(200, spark.table(tableName).count());
+
+ // update a record from one of two row groups and copy over the second one
+ sql("MERGE INTO %s t USING source " +
+ "ON t.id == source.value " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET dep = 'x'", tableName);
+
+ Assert.assertEquals(200, spark.table(tableName).count());
+ }
+
+ @Test
+ public void testMergeInsertOnly() {
+ createAndInitTable("id STRING, v STRING",
+ "{ \"id\": \"a\", \"v\": \"v1\" }\n" +
+ "{ \"id\": \"b\", \"v\": \"v2\" }");
+ createOrReplaceView("source",
+ "{ \"id\": \"a\", \"v\": \"v1_1\" }\n" +
+ "{ \"id\": \"a\", \"v\": \"v1_2\" }\n" +
+ "{ \"id\": \"c\", \"v\": \"v3\" }\n" +
+ "{ \"id\": \"d\", \"v\": \"v4_1\" }\n" +
+ "{ \"id\": \"d\", \"v\": \"v4_2\" }");
+
+ sql("MERGE INTO %s t USING source " +
+ "ON t.id == source.id " +
+ "WHEN NOT MATCHED THEN " +
+ " INSERT *", tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row("a", "v1"), // kept
+ row("b", "v2"), // kept
+ row("c", "v3"), // new
+ row("d", "v4_1"), // new
+ row("d", "v4_2") // new
+ );
+ assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testMergeInsertOnlyWithCondition() {
+ createAndInitTable("id INTEGER, v INTEGER", "{ \"id\": 1, \"v\": 1 }");
+ createOrReplaceView("source",
+ "{ \"id\": 1, \"v\": 11, \"is_new\": true }\n" +
+ "{ \"id\": 2, \"v\": 21, \"is_new\": true }\n" +
+ "{ \"id\": 2, \"v\": 22, \"is_new\": false }");
+
+ // validate assignments are reordered to match the table attrs
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.id " +
+ "WHEN NOT MATCHED AND is_new = TRUE THEN " +
+ " INSERT (v, id) VALUES (s.v + 100, s.id)", tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(1, 1), // kept
+ row(2, 121) // new
+ );
+ assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testMergeAlignsUpdateAndInsertActions() {
+ createAndInitTable("id INT, a INT, b STRING", "{ \"id\": 1, \"a\": 2, \"b\": \"str\" }");
+ createOrReplaceView("source",
+ "{ \"id\": 1, \"c1\": -2, \"c2\": \"new_str_1\" }\n" +
+ "{ \"id\": 2, \"c1\": -20, \"c2\": \"new_str_2\" }");
+
+ sql("MERGE INTO %s t USING source " +
+ "ON t.id == source.id " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET b = c2, a = c1, t.id = source.id " +
+ "WHEN NOT MATCHED THEN " +
+ " INSERT (b, a, id) VALUES (c2, c1, id)", tableName);
+
+ assertEquals("Output should match",
+ ImmutableList.of(row(1, -2, "new_str_1"), row(2, -20, "new_str_2")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testMergeUpdatesNestedStructFields() {
+ createAndInitTable("id INT, s STRUCT<c1:INT,c2:STRUCT<a:ARRAY<INT>,m:MAP<STRING, STRING>>>",
+ "{ \"id\": 1, \"s\": { \"c1\": 2, \"c2\": { \"a\": [1,2], \"m\": { \"a\": \"b\"} } } } }");
+ createOrReplaceView("source", "{ \"id\": 1, \"c1\": -2 }");
+
+ // update primitive, array, map columns inside a struct
+ sql("MERGE INTO %s t USING source " +
+ "ON t.id == source.id " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.s.c1 = source.c1, t.s.c2.a = array(-1, -2), t.s.c2.m = map('k', 'v')", tableName);
+
+ assertEquals("Output should match",
+ ImmutableList.of(row(1, row(-2, row(ImmutableList.of(-1, -2), ImmutableMap.of("k", "v"))))),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ // set primitive, array, map columns to NULL (proper casts should be in place)
+ sql("MERGE INTO %s t USING source " +
+ "ON t.id == source.id " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.s.c1 = NULL, t.s.c2 = NULL", tableName);
+
+ assertEquals("Output should match",
+ ImmutableList.of(row(1, row(null, null))),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ // update all fields in a struct
+ sql("MERGE INTO %s t USING source " +
+ "ON t.id == source.id " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.s = named_struct('c1', 100, 'c2', named_struct('a', array(1), 'm', map('x', 'y')))", tableName);
+
+ assertEquals("Output should match",
+ ImmutableList.of(row(1, row(100, row(ImmutableList.of(1), ImmutableMap.of("x", "y"))))),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testMergeWithInferredCasts() {
+ createAndInitTable("id INT, s STRING", "{ \"id\": 1, \"s\": \"value\" }");
+ createOrReplaceView("source", "{ \"id\": 1, \"c1\": -2}");
+
+ // -2 in source should be cast to "-2" in target
+ sql("MERGE INTO %s t USING source " +
+ "ON t.id == source.id " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.s = source.c1", tableName);
+
+ assertEquals("Output should match",
+ ImmutableList.of(row(1, "-2")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testMergeModifiesNullStruct() {
+ createAndInitTable("id INT, s STRUCT<n1:INT,n2:INT>", "{ \"id\": 1, \"s\": null }");
+ createOrReplaceView("source", "{ \"id\": 1, \"n1\": -10 }");
+
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.id " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.s.n1 = s.n1", tableName);
+
+ assertEquals("Output should match",
+ ImmutableList.of(row(1, row(-10, null))),
+ sql("SELECT * FROM %s", tableName));
+ }
+
+ @Test
+ public void testMergeRefreshesRelationCache() {
+ createAndInitTable("id INT, name STRING", "{ \"id\": 1, \"name\": \"n1\" }");
+ createOrReplaceView("source", "{ \"id\": 1, \"name\": \"n2\" }");
+
+ Dataset<Row> query = spark.sql("SELECT name FROM " + tableName);
+ query.createOrReplaceTempView("tmp");
+
+ spark.sql("CACHE TABLE tmp");
+
+ assertEquals("View should have correct data",
+ ImmutableList.of(row("n1")),
+ sql("SELECT * FROM tmp"));
+
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.id " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.name = s.name", tableName);
+
+ assertEquals("View should have correct data",
+ ImmutableList.of(row("n2")),
+ sql("SELECT * FROM tmp"));
+
+ spark.sql("UNCACHE TABLE tmp");
+ }
+
+ @Test
+ public void testMergeWithNonExistingColumns() {
+ createAndInitTable("id INT, c STRUCT<n1:INT,n2:STRUCT<dn1:INT,dn2:INT>>");
+ createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }");
+
+ AssertHelpers.assertThrows("Should complain about the invalid top-level column",
+ AnalysisException.class, "cannot resolve '`t.invalid_col`'",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.invalid_col = s.c2", tableName);
+ });
+
+ AssertHelpers.assertThrows("Should complain about the invalid nested column",
+ AnalysisException.class, "No such struct field invalid_col",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.c.n2.invalid_col = s.c2", tableName);
+ });
+
+ AssertHelpers.assertThrows("Should complain about the invalid top-level column",
+ AnalysisException.class, "cannot resolve '`invalid_col`'",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.c.n2.dn1 = s.c2 " +
+ "WHEN NOT MATCHED THEN " +
+ " INSERT (id, invalid_col) VALUES (s.c1, null)", tableName);
+ });
+ }
+
+ @Test
+ public void testMergeWithInvalidColumnsInInsert() {
+ createAndInitTable("id INT, c STRUCT<n1:INT,n2:STRUCT<dn1:INT,dn2:INT>>");
+ createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }");
+
+ AssertHelpers.assertThrows("Should complain about the nested column",
+ AnalysisException.class, "Nested fields are not supported inside INSERT clauses",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.c.n2.dn1 = s.c2 " +
+ "WHEN NOT MATCHED THEN " +
+ " INSERT (id, c.n2) VALUES (s.c1, null)", tableName);
+ });
+
+ AssertHelpers.assertThrows("Should complain about duplicate columns",
+ AnalysisException.class, "Duplicate column names inside INSERT clause",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.c.n2.dn1 = s.c2 " +
+ "WHEN NOT MATCHED THEN " +
+ " INSERT (id, id) VALUES (s.c1, null)", tableName);
+ });
+
+ AssertHelpers.assertThrows("Should complain about missing columns",
+ AnalysisException.class, "must provide values for all columns of the target table",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN NOT MATCHED THEN " +
+ " INSERT (id) VALUES (s.c1)", tableName);
+ });
+ }
+
+ @Test
+ public void testMergeWithInvalidUpdates() {
+ createAndInitTable("id INT, a ARRAY<STRUCT<c1:INT,c2:INT>>, m MAP<STRING,STRING>");
+ createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }");
+
+ AssertHelpers.assertThrows("Should complain about updating an array column",
+ AnalysisException.class, "Updating nested fields is only supported for structs",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.a.c1 = s.c2", tableName);
+ });
+
+ AssertHelpers.assertThrows("Should complain about updating a map column",
+ AnalysisException.class, "Updating nested fields is only supported for structs",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.m.key = 'new_key'", tableName);
+ });
+ }
+
+ @Test
+ public void testMergeWithConflictingUpdates() {
+ createAndInitTable("id INT, c STRUCT<n1:INT,n2:STRUCT<dn1:INT,dn2:INT>>");
+ createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }");
+
+ AssertHelpers.assertThrows("Should complain about conflicting updates to a top-level column",
+ AnalysisException.class, "Updates are in conflict",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.id = 1, t.c.n1 = 2, t.id = 2", tableName);
+ });
+
+ AssertHelpers.assertThrows("Should complain about conflicting updates to a nested column",
+ AnalysisException.class, "Updates are in conflict for these columns",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.c.n1 = 1, t.id = 2, t.c.n1 = 2", tableName);
+ });
+
+ AssertHelpers.assertThrows("Should complain about conflicting updates to a nested column",
+ AnalysisException.class, "Updates are in conflict",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET c.n1 = 1, c = named_struct('n1', 1, 'n2', named_struct('dn1', 1, 'dn2', 2))", tableName);
+ });
+ }
+
+ @Test
+ public void testMergeWithInvalidAssignments() {
+ createAndInitTable("id INT NOT NULL, s STRUCT<n1:INT NOT NULL,n2:STRUCT<dn1:INT,dn2:INT>> NOT NULL");
+ createOrReplaceView(
+ "source",
+ "c1 INT, c2 STRUCT<n1:INT NOT NULL> NOT NULL, c3 STRING NOT NULL, c4 STRUCT<dn2:INT,dn1:INT>",
+ "{ \"c1\": -100, \"c2\": { \"n1\" : 1 }, \"c3\" : 'str', \"c4\": { \"dn2\": 1, \"dn2\": 2 } }");
+
+ for (String policy : new String[]{"ansi", "strict"}) {
+ withSQLConf(ImmutableMap.of("spark.sql.storeAssignmentPolicy", policy), () -> {
+
+ AssertHelpers.assertThrows("Should complain about writing nulls to a top-level column",
+ AnalysisException.class, "Cannot write nullable values to non-null column",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.id = NULL", tableName);
+ });
+
+ AssertHelpers.assertThrows("Should complain about writing nulls to a nested column",
+ AnalysisException.class, "Cannot write nullable values to non-null column",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.s.n1 = NULL", tableName);
+ });
+
+ AssertHelpers.assertThrows("Should complain about writing missing fields in structs",
+ AnalysisException.class, "missing fields",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.s = s.c2", tableName);
+ });
+
+ AssertHelpers.assertThrows("Should complain about writing invalid data types",
+ AnalysisException.class, "Cannot safely cast",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.s.n1 = s.c3", tableName);
+ });
+
+ AssertHelpers.assertThrows("Should complain about writing incompatible structs",
+ AnalysisException.class, "field name does not match",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.s.n2 = s.c4", tableName);
+ });
+ });
+ }
+ }
+
+ @Test
+ public void testMergeWithNonDeterministicConditions() {
+ createAndInitTable("id INT, c STRUCT<n1:INT,n2:STRUCT<dn1:INT,dn2:INT>>");
+ createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }");
+
+ AssertHelpers.assertThrows("Should complain about non-deterministic search conditions",
+ AnalysisException.class, "nondeterministic expressions are only allowed in",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 AND rand() > t.id " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.c.n1 = -1", tableName);
+ });
+
+ AssertHelpers.assertThrows("Should complain about non-deterministic update conditions",
+ AnalysisException.class, "nondeterministic expressions are only allowed in",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED AND rand() > t.id THEN " +
+ " UPDATE SET t.c.n1 = -1", tableName);
+ });
+
+ AssertHelpers.assertThrows("Should complain about non-deterministic delete conditions",
+ AnalysisException.class, "nondeterministic expressions are only allowed in",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED AND rand() > t.id THEN " +
+ " DELETE", tableName);
+ });
+
+ AssertHelpers.assertThrows("Should complain about non-deterministic insert conditions",
+ AnalysisException.class, "nondeterministic expressions are only allowed in",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN NOT MATCHED AND rand() > c1 THEN " +
+ " INSERT (id, c) VALUES (1, null)", tableName);
+ });
+ }
+
+ @Test
+ public void testMergeWithAggregateExpressions() {
+ createAndInitTable("id INT, c STRUCT<n1:INT,n2:STRUCT<dn1:INT,dn2:INT>>");
+ createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }");
+
+ AssertHelpers.assertThrows("Should complain about agg expressions in search conditions",
+ AnalysisException.class, "contains one or more unsupported",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 AND max(t.id) == 1 " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.c.n1 = -1", tableName);
+ });
+
+ AssertHelpers.assertThrows("Should complain about agg expressions in update conditions",
+ AnalysisException.class, "contains one or more unsupported",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED AND sum(t.id) < 1 THEN " +
+ " UPDATE SET t.c.n1 = -1", tableName);
+ });
+
+ AssertHelpers.assertThrows("Should complain about non-deterministic delete conditions",
+ AnalysisException.class, "contains one or more unsupported",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED AND sum(t.id) THEN " +
+ " DELETE", tableName);
+ });
+
+ AssertHelpers.assertThrows("Should complain about non-deterministic insert conditions",
+ AnalysisException.class, "contains one or more unsupported",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN NOT MATCHED AND sum(c1) < 1 THEN " +
+ " INSERT (id, c) VALUES (1, null)", tableName);
+ });
+ }
+
+ @Test
+ public void testMergeWithSubqueriesInConditions() {
+ createAndInitTable("id INT, c STRUCT<n1:INT,n2:STRUCT<dn1:INT,dn2:INT>>");
+ createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }");
+
+ AssertHelpers.assertThrows("Should complain about subquery expressions",
+ AnalysisException.class, "Subqueries are not supported in conditions",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 AND t.id < (SELECT max(c2) FROM source) " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET t.c.n1 = s.c2", tableName);
+ });
+
+ AssertHelpers.assertThrows("Should complain about subquery expressions",
+ AnalysisException.class, "Subqueries are not supported in conditions",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED AND t.id < (SELECT max(c2) FROM source) THEN " +
+ " UPDATE SET t.c.n1 = s.c2", tableName);
+ });
+
+ AssertHelpers.assertThrows("Should complain about subquery expressions",
+ AnalysisException.class, "Subqueries are not supported in conditions",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN MATCHED AND t.id NOT IN (SELECT c2 FROM source) THEN " +
+ " DELETE", tableName);
+ });
+
+ AssertHelpers.assertThrows("Should complain about subquery expressions",
+ AnalysisException.class, "Subqueries are not supported in conditions",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.c1 " +
+ "WHEN NOT MATCHED AND s.c1 IN (SELECT c2 FROM source) THEN " +
+ " INSERT (id, c) VALUES (1, null)", tableName);
+ });
+ }
+
+ @Test
+ public void testMergeWithTargetColumnsInInsertConditions() {
+ createAndInitTable("id INT, c2 INT");
+ createOrReplaceView("source", "{ \"id\": 1, \"value\": 11 }");
+
+ AssertHelpers.assertThrows("Should complain about the target column",
+ AnalysisException.class, "cannot resolve '`c2`'",
+ () -> {
+ sql("MERGE INTO %s t USING source s " +
+ "ON t.id == s.id " +
+ "WHEN NOT MATCHED AND c2 = 1 THEN " +
+ " INSERT (id, c2) VALUES (s.id, null)", tableName);
+ });
+ }
+
+ @Test
+ public void testMergeWithNonIcebergTargetTableNotSupported() {
+ createOrReplaceView("target", "{ \"c1\": -100, \"c2\": -200 }");
+ createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }");
+
+ AssertHelpers.assertThrows("Should complain non iceberg target table",
+ UnsupportedOperationException.class, "MERGE INTO TABLE is not supported temporarily.",
+ () -> {
+ sql("MERGE INTO target t USING source s " +
+ "ON t.c1 == s.c1 " +
+ "WHEN MATCHED THEN " +
+ " UPDATE SET *");
+ });
+ }
+
+ /**
+ * Tests a merge where both the source and target are evaluated to be partitioned by SinglePartition at planning time
+ * but DynamicFileFilterExec will return an empty target.
+ */
+ @Test
+ public void testMergeSinglePartitionPartitioning() {
+ // This table will only have a single file and a single partition
+ createAndInitTable("id INT", "{\"id\": -1}");
+
+ // Coalesce forces our source into a SinglePartition distribution
+ spark.range(0, 5).coalesce(1).createOrReplaceTempView("source");
+
+ sql("MERGE INTO %s t USING source s ON t.id = s.id " +
+ "WHEN MATCHED THEN UPDATE SET *" +
+ "WHEN NOT MATCHED THEN INSERT *",
+ tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(-1),
+ row(0),
+ row(1),
+ row(2),
+ row(3),
+ row(4)
+ );
+
+ List<Object[]> result = sql("SELECT * FROM %s ORDER BY id", tableName);
+ assertEquals("Should correctly add the non-matching rows", expectedRows, result);
+ }
+
+ @Test
+ public void testMergeEmptyTable() {
+ // This table will only have a single file and a single partition
+ createAndInitTable("id INT", null);
+
+ // Coalesce forces our source into a SinglePartition distribution
+ spark.range(0, 5).coalesce(1).createOrReplaceTempView("source");
+
+ sql("MERGE INTO %s t USING source s ON t.id = s.id " +
+ "WHEN MATCHED THEN UPDATE SET *" +
+ "WHEN NOT MATCHED THEN INSERT *",
+ tableName);
+
+ ImmutableList<Object[]> expectedRows = ImmutableList.of(
+ row(0),
+ row(1),
+ row(2),
+ row(3),
+ row(4)
+ );
+
+ List<Object[]> result = sql("SELECT * FROM %s ORDER BY id", tableName);
+ assertEquals("Should correctly add the non-matching rows", expectedRows, result);
+ }
+}
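The MERGE tests above are driven by a few table properties imported at the top of the file: the cardinality check, the MERGE isolation level, and the write distribution mode. The standalone sketch below shows the same flow outside the test harness; the catalog and table names are illustrative assumptions, and the literal property keys are assumed to match the TableProperties constants imported in the test.

    import org.apache.spark.sql.SparkSession;

    public class MergeExample {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .appName("merge-example")
            .getOrCreate();

        // Hash-distribute rows written by MERGE and keep the cardinality check enabled,
        // so a target row matched by multiple source rows fails instead of being duplicated.
        spark.sql("ALTER TABLE my_catalog.db.target SET TBLPROPERTIES(" +
            "'write.distribution-mode' 'hash'," +
            "'write.merge.cardinality-check.enabled' 'true')");

        // Conditional update, delete, and insert clauses, mirroring the tests above.
        spark.sql("MERGE INTO my_catalog.db.target t USING my_catalog.db.source s " +
            "ON t.id == s.id " +
            "WHEN MATCHED AND t.id = 1 THEN UPDATE SET * " +
            "WHEN MATCHED AND t.id = 6 THEN DELETE " +
            "WHEN NOT MATCHED THEN INSERT *");
      }
    }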
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMigrateTableProcedure.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMigrateTableProcedure.java
new file mode 100644
index 0000000..d66e75a
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMigrateTableProcedure.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.io.IOException;
+import java.util.Map;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.exceptions.ValidationException;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.spark.sql.AnalysisException;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Assume;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+public class TestMigrateTableProcedure extends SparkExtensionsTestBase {
+
+ public TestMigrateTableProcedure(String catalogName, String implementation, Map<String, String> config) {
+ super(catalogName, implementation, config);
+ }
+
+ @Rule
+ public TemporaryFolder temp = new TemporaryFolder();
+
+ @After
+ public void removeTables() {
+ sql("DROP TABLE IF EXISTS %s", tableName);
+ sql("DROP TABLE IF EXISTS %s_BACKUP_", tableName);
+ }
+
+ @Test
+ public void testMigrate() throws IOException {
+ Assume.assumeTrue(catalogName.equals("spark_catalog"));
+ String location = temp.newFolder().toString();
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", tableName, location);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+ Object result = scalarSql("CALL %s.system.migrate('%s')", catalogName, tableName);
+
+ Assert.assertEquals("Should have migrated one file", 1L, result);
+
+ Table createdTable = validationCatalog.loadTable(tableIdent);
+
+ String tableLocation = createdTable.location().replace("file:", "");
+ Assert.assertEquals("Table should have original location", location, tableLocation);
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a"), row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
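+ // migrate keeps the original table as <table>_BACKUP_, so drop the backup explicitly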
+ sql("DROP TABLE %s", tableName + "_BACKUP_");
+ }
+
+ @Test
+ public void testMigrateWithOptions() throws IOException {
+ Assume.assumeTrue(catalogName.equals("spark_catalog"));
+ String location = temp.newFolder().toString();
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", tableName, location);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ Object result = scalarSql("CALL %s.system.migrate('%s', map('foo', 'bar'))", catalogName, tableName);
+
+ Assert.assertEquals("Should have migrated one file", 1L, result);
+
+ Table createdTable = validationCatalog.loadTable(tableIdent);
+
+ Map<String, String> props = createdTable.properties();
+ Assert.assertEquals("Should have extra property set", "bar", props.get("foo"));
+
+ String tableLocation = createdTable.location().replace("file:", "");
+ Assert.assertEquals("Table should have original location", location, tableLocation);
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a"), row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ sql("DROP TABLE %s", tableName + "_BACKUP_");
+ }
+
+ @Test
+ public void testMigrateWithInvalidMetricsConfig() throws IOException {
+ Assume.assumeTrue(catalogName.equals("spark_catalog"));
+
+ String location = temp.newFolder().toString();
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", tableName, location);
+
+ AssertHelpers.assertThrows("Should reject invalid metrics config",
+ ValidationException.class, "Invalid metrics config",
+ () -> {
+ String props = "map('write.metadata.metrics.column.x', 'X')";
+ sql("CALL %s.system.migrate('%s', %s)", catalogName, tableName, props);
+ });
+ }
+
+ @Test
+ public void testMigrateWithConflictingProps() throws IOException {
+ Assume.assumeTrue(catalogName.equals("spark_catalog"));
+
+ String location = temp.newFolder().toString();
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", tableName, location);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ Object result = scalarSql("CALL %s.system.migrate('%s', map('migrated', 'false'))", catalogName, tableName);
+ Assert.assertEquals("Should have migrated one file", 1L, result);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM %s", tableName));
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertEquals("Should override user value", "true", table.properties().get("migrated"));
+ }
+
+ @Test
+ public void testInvalidMigrateCases() {
+ AssertHelpers.assertThrows("Should reject calls without all required args",
+ AnalysisException.class, "Missing required parameters",
+ () -> sql("CALL %s.system.migrate()", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls with invalid arg types",
+ AnalysisException.class, "Wrong arg type",
+ () -> sql("CALL %s.system.migrate(map('foo','bar'))", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls with empty table identifier",
+ IllegalArgumentException.class, "Cannot handle an empty identifier",
+ () -> sql("CALL %s.system.migrate('')", catalogName));
+ }
+}
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRemoveOrphanFilesProcedure.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRemoveOrphanFilesProcedure.java
new file mode 100644
index 0000000..c1a4ec5
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRemoveOrphanFilesProcedure.java
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.io.IOException;
+import java.sql.Timestamp;
+import java.time.Instant;
+import java.util.List;
+import java.util.Map;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.exceptions.ValidationException;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.spark.sql.AnalysisException;
+import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import static org.apache.iceberg.TableProperties.GC_ENABLED;
+import static org.apache.iceberg.TableProperties.WRITE_AUDIT_PUBLISH_ENABLED;
+
+public class TestRemoveOrphanFilesProcedure extends SparkExtensionsTestBase {
+
+ @Rule
+ public TemporaryFolder temp = new TemporaryFolder();
+
+ public TestRemoveOrphanFilesProcedure(String catalogName, String implementation, Map<String, String> config) {
+ super(catalogName, implementation, config);
+ }
+
+ @After
+ public void removeTable() {
+ sql("DROP TABLE IF EXISTS %s", tableName);
+ sql("DROP TABLE IF EXISTS p");
+ }
+
+ @Test
+ public void testRemoveOrphanFilesInEmptyTable() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+
+ List<Object[]> output = sql(
+ "CALL %s.system.remove_orphan_files('%s')",
+ catalogName, tableIdent);
+ assertEquals("Should be no orphan files", ImmutableList.of(), output);
+
+ assertEquals("Should have no rows",
+ ImmutableList.of(),
+ sql("SELECT * FROM %s", tableName));
+ }
+
+ @Test
+ public void testRemoveOrphanFilesInDataFolder() throws IOException {
+ if (catalogName.equals("testhadoop")) {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ } else {
+ // give a fresh location to Hive tables as Spark will not clean up the table location
+ // correctly while dropping tables through spark_catalog
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg LOCATION '%s'",
+ tableName, temp.newFolder());
+ }
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+ sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ String metadataLocation = table.location() + "/metadata";
+ String dataLocation = table.location() + "/data";
+
+ // produce orphan files in the data location using parquet
+ sql("CREATE TABLE p (id bigint) USING parquet LOCATION '%s'", dataLocation);
+ sql("INSERT INTO TABLE p VALUES (1)");
+
+ // wait to ensure files are old enough
+ waitUntilAfter(System.currentTimeMillis());
+
+ Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis()));
+
+ // check for orphans in the metadata folder
+ List<Object[]> output1 = sql(
+ "CALL %s.system.remove_orphan_files(" +
+ "table => '%s'," +
+ "older_than => TIMESTAMP '%s'," +
+ "location => '%s')",
+ catalogName, tableIdent, currentTimestamp, metadataLocation);
+ assertEquals("Should be no orphan files in the metadata folder", ImmutableList.of(), output1);
+
+ // check for orphans in the table location
+ List<Object[]> output2 = sql(
+ "CALL %s.system.remove_orphan_files(" +
+ "table => '%s'," +
+ "older_than => TIMESTAMP '%s')",
+ catalogName, tableIdent, currentTimestamp);
+ Assert.assertEquals("Should be one orphan file in the data folder", 1, output2.size());
+
+ // the previous call should have deleted all orphan files
+ List<Object[]> output3 = sql(
+ "CALL %s.system.remove_orphan_files(" +
+ "table => '%s'," +
+ "older_than => TIMESTAMP '%s')",
+ catalogName, tableIdent, currentTimestamp);
+ Assert.assertEquals("Should be no more orphan files in the data folder", 0, output3.size());
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a"), row(2L, "b")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testRemoveOrphanFilesDryRun() throws IOException {
+ if (catalogName.equals("testhadoop")) {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ } else {
+ // give a fresh location to Hive tables as Spark will not clean up the table location
+ // correctly while dropping tables through spark_catalog
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg LOCATION '%s'",
+ tableName, temp.newFolder());
+ }
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+ sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ // produce orphan files in the table location using parquet
+ sql("CREATE TABLE p (id bigint) USING parquet LOCATION '%s'", table.location());
+ sql("INSERT INTO TABLE p VALUES (1)");
+
+ // wait to ensure files are old enough
+ waitUntilAfter(System.currentTimeMillis());
+
+ Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis()));
+
+ // check for orphans without deleting
+ List<Object[]> output1 = sql(
+ "CALL %s.system.remove_orphan_files(" +
+ "table => '%s'," +
+ "older_than => TIMESTAMP '%s'," +
+ "dry_run => true)",
+ catalogName, tableIdent, currentTimestamp);
+ Assert.assertEquals("Should be one orphan file", 1, output1.size());
+
+ // actually delete orphans
+ List<Object[]> output2 = sql(
+ "CALL %s.system.remove_orphan_files(" +
+ "table => '%s'," +
+ "older_than => TIMESTAMP '%s')",
+ catalogName, tableIdent, currentTimestamp);
+ Assert.assertEquals("Should be one orphan file", 1, output2.size());
+
+ // the previous call should have deleted all orphan files
+ List<Object[]> output3 = sql(
+ "CALL %s.system.remove_orphan_files(" +
+ "table => '%s'," +
+ "older_than => TIMESTAMP '%s')",
+ catalogName, tableIdent, currentTimestamp);
+ Assert.assertEquals("Should be no more orphan files", 0, output3.size());
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a"), row(2L, "b")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testRemoveOrphanFilesGCDisabled() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+
+ sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'false')", tableName, GC_ENABLED);
+
+ AssertHelpers.assertThrows("Should reject call",
+ ValidationException.class, "Cannot remove orphan files: GC is disabled",
+ () -> sql("CALL %s.system.remove_orphan_files('%s')", catalogName, tableIdent));
+ }
+
+ @Test
+ public void testRemoveOrphanFilesWap() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED);
+
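+ // stage the following writes under a WAP id instead of publishing them as the current snapshot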
+ spark.conf().set("spark.wap.id", "1");
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should not see rows from staged snapshot",
+ ImmutableList.of(),
+ sql("SELECT * FROM %s", tableName));
+
+ List<Object[]> output = sql(
+ "CALL %s.system.remove_orphan_files('%s')", catalogName, tableIdent);
+ assertEquals("Should be no orphan files", ImmutableList.of(), output);
+ }
+
+ @Test
+ public void testInvalidRemoveOrphanFilesCases() {
+ AssertHelpers.assertThrows("Should not allow mixed args",
+ AnalysisException.class, "Named and positional arguments cannot be mixed",
+ () -> sql("CALL %s.system.remove_orphan_files('n', table => 't')", catalogName));
+
+ AssertHelpers.assertThrows("Should not resolve procedures in arbitrary namespaces",
+ NoSuchProcedureException.class, "not found",
+ () -> sql("CALL %s.custom.remove_orphan_files('n', 't')", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls without all required args",
+ AnalysisException.class, "Missing required parameters",
+ () -> sql("CALL %s.system.remove_orphan_files()", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls with invalid arg types",
+ AnalysisException.class, "Wrong arg type",
+ () -> sql("CALL %s.system.remove_orphan_files('n', 2.2)", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls with empty table identifier",
+ IllegalArgumentException.class, "Cannot handle an empty identifier",
+ () -> sql("CALL %s.system.remove_orphan_files('')", catalogName));
+ }
+}
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteManifestsProcedure.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteManifestsProcedure.java
new file mode 100644
index 0000000..b04f176
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteManifestsProcedure.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.spark.sql.AnalysisException;
+import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Test;
+
+import static org.apache.iceberg.TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED;
+
+public class TestRewriteManifestsProcedure extends SparkExtensionsTestBase {
+
+ public TestRewriteManifestsProcedure(String catalogName, String implementation, Map<String, String> config) {
+ super(catalogName, implementation, config);
+ }
+
+ @After
+ public void removeTable() {
+ sql("DROP TABLE IF EXISTS %s", tableName);
+ }
+
+ @Test
+ public void testRewriteManifestsInEmptyTable() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ List<Object[]> output = sql(
+ "CALL %s.system.rewrite_manifests('%s')", catalogName, tableIdent);
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(0, 0)),
+ output);
+ }
+
+ @Test
+ public void testRewriteLargeManifests() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg PARTITIONED BY (data)", tableName);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a'), (2, 'b'), (3, 'c'), (4, 'd')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ Assert.assertEquals("Must have 1 manifest", 1, table.currentSnapshot().allManifests().size());
+
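+ // a tiny manifest target size forces the rewrite to split the single manifest into one manifest per data file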
+ sql("ALTER TABLE %s SET TBLPROPERTIES ('commit.manifest.target-size-bytes' '1')", tableName);
+
+ List<Object[]> output = sql(
+ "CALL %s.system.rewrite_manifests('%s')", catalogName, tableIdent);
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(1, 4)),
+ output);
+
+ table.refresh();
+
+ Assert.assertEquals("Must have 4 manifests", 4, table.currentSnapshot().allManifests().size());
+ }
+
+ @Test
+ public void testRewriteSmallManifestsWithSnapshotIdInheritance() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg PARTITIONED BY (data)", tableName);
+
+ sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s')", tableName, SNAPSHOT_ID_INHERITANCE_ENABLED, "true");
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+ sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName);
+ sql("INSERT INTO TABLE %s VALUES (3, 'c')", tableName);
+ sql("INSERT INTO TABLE %s VALUES (4, 'd')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ Assert.assertEquals("Must have 4 manifests", 4, table.currentSnapshot().allManifests().size());
+
+ List<Object[]> output = sql(
+ "CALL %s.system.rewrite_manifests(table => '%s')", catalogName, tableIdent);
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(4, 1)),
+ output);
+
+ table.refresh();
+
+ Assert.assertEquals("Must have 1 manifest", 1, table.currentSnapshot().allManifests().size());
+ }
+
+ @Test
+ public void testRewriteSmallManifestsWithoutCaching() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg PARTITIONED BY (data)", tableName);
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+ sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ Assert.assertEquals("Must have 2 manifests", 2, table.currentSnapshot().allManifests().size());
+
+ List<Object[]> output = sql(
+ "CALL %s.system.rewrite_manifests(use_caching => false, table => '%s')", catalogName, tableIdent);
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(2, 1)),
+ output);
+
+ table.refresh();
+
+ Assert.assertEquals("Must have 1 manifest", 1, table.currentSnapshot().allManifests().size());
+ }
+
+ @Test
+ public void testRewriteManifestsCaseInsensitiveArgs() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg PARTITIONED BY (data)", tableName);
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+ sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ Assert.assertEquals("Must have 2 manifests", 2, table.currentSnapshot().allManifests().size());
+
+ List<Object[]> output = sql(
+ "CALL %s.system.rewrite_manifests(usE_cAcHiNg => false, tAbLe => '%s')", catalogName, tableIdent);
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(2, 1)),
+ output);
+
+ table.refresh();
+
+ Assert.assertEquals("Must have 1 manifest", 1, table.currentSnapshot().allManifests().size());
+ }
+
+ @Test
+ public void testInvalidRewriteManifestsCases() {
+ AssertHelpers.assertThrows("Should not allow mixed args",
+ AnalysisException.class, "Named and positional arguments cannot be mixed",
+ () -> sql("CALL %s.system.rewrite_manifests('n', table => 't')", catalogName));
+
+ AssertHelpers.assertThrows("Should not resolve procedures in arbitrary namespaces",
+ NoSuchProcedureException.class, "not found",
+ () -> sql("CALL %s.custom.rewrite_manifests('n', 't')", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls without all required args",
+ AnalysisException.class, "Missing required parameters",
+ () -> sql("CALL %s.system.rewrite_manifests()", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls with invalid arg types",
+ AnalysisException.class, "Wrong arg type",
+ () -> sql("CALL %s.system.rewrite_manifests('n', 2.2)", catalogName));
+
+ AssertHelpers.assertThrows("Should reject duplicate arg names",
+ AnalysisException.class, "Duplicate procedure argument: table",
+ () -> sql("CALL %s.system.rewrite_manifests(table => 't', tAbLe => 't')", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls with empty table identifier",
+ IllegalArgumentException.class, "Cannot handle an empty identifier",
+ () -> sql("CALL %s.system.rewrite_manifests('')", catalogName));
+ }
+}
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToSnapshotProcedure.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToSnapshotProcedure.java
new file mode 100644
index 0000000..d3e6bdc
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToSnapshotProcedure.java
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.Snapshot;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.exceptions.ValidationException;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.spark.sql.AnalysisException;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
+import org.junit.After;
+import org.junit.Assume;
+import org.junit.Test;
+
+public class TestRollbackToSnapshotProcedure extends SparkExtensionsTestBase {
+
+ public TestRollbackToSnapshotProcedure(String catalogName, String implementation, Map<String, String> config) {
+ super(catalogName, implementation, config);
+ }
+
+ @After
+ public void removeTables() {
+ sql("DROP TABLE IF EXISTS %s", tableName);
+ }
+
+ @Test
+ public void testRollbackToSnapshotUsingPositionalArgs() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Snapshot firstSnapshot = table.currentSnapshot();
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a"), row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ table.refresh();
+
+ Snapshot secondSnapshot = table.currentSnapshot();
+
+ List<Object[]> output = sql(
+ "CALL %s.system.rollback_to_snapshot('%s', %dL)",
+ catalogName, tableIdent, firstSnapshot.snapshotId());
+
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())),
+ output);
+
+ assertEquals("Rollback must be successful",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testRollbackToSnapshotUsingNamedArgs() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Snapshot firstSnapshot = table.currentSnapshot();
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a"), row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ table.refresh();
+
+ Snapshot secondSnapshot = table.currentSnapshot();
+
+ List<Object[]> output = sql(
+ "CALL %s.system.rollback_to_snapshot(snapshot_id => %dL, table => '%s')",
+ catalogName, firstSnapshot.snapshotId(), tableIdent);
+
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())),
+ output);
+
+ assertEquals("Rollback must be successful",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testRollbackToSnapshotRefreshesRelationCache() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Snapshot firstSnapshot = table.currentSnapshot();
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ table.refresh();
+
+ Snapshot secondSnapshot = table.currentSnapshot();
+
+ Dataset<Row> query = spark.sql("SELECT * FROM " + tableName + " WHERE id = 1");
+ query.createOrReplaceTempView("tmp");
+
+ spark.sql("CACHE TABLE tmp");
+
+ assertEquals("View should have expected rows",
+ ImmutableList.of(row(1L, "a"), row(1L, "a")),
+ sql("SELECT * FROM tmp"));
+
+ List<Object[]> output = sql(
+ "CALL %s.system.rollback_to_snapshot(table => '%s', snapshot_id => %dL)",
+ catalogName, tableIdent, firstSnapshot.snapshotId());
+
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())),
+ output);
+
+ assertEquals("View cache must be invalidated",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM tmp"));
+
+ sql("UNCACHE TABLE tmp");
+ }
+
+ @Test
+ public void testRollbackToSnapshotWithQuotedIdentifiers() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Snapshot firstSnapshot = table.currentSnapshot();
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a"), row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ table.refresh();
+
+ Snapshot secondSnapshot = table.currentSnapshot();
+
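+ // back-tick quote each namespace level to exercise quoted identifier parsing in the CALL statement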
+ StringBuilder quotedNamespaceBuilder = new StringBuilder();
+ for (String level : tableIdent.namespace().levels()) {
+ quotedNamespaceBuilder.append("`");
+ quotedNamespaceBuilder.append(level);
+ quotedNamespaceBuilder.append("`");
+ }
+ String quotedNamespace = quotedNamespaceBuilder.toString();
+
+ List<Object[]> output = sql(
+ "CALL %s.system.rollback_to_snapshot('%s', %d)",
+ catalogName, quotedNamespace + ".`" + tableIdent.name() + "`", firstSnapshot.snapshotId());
+
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())),
+ output);
+
+ assertEquals("Rollback must be successful",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testRollbackToSnapshotWithoutExplicitCatalog() {
+ Assume.assumeTrue("Working only with the session catalog", "spark_catalog".equals(catalogName));
+
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Snapshot firstSnapshot = table.currentSnapshot();
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a"), row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ table.refresh();
+
+ Snapshot secondSnapshot = table.currentSnapshot();
+
+ // use camel case intentionally to test case insensitivity
+ List<Object[]> output = sql(
+ "CALL SyStEm.rOLlBaCk_to_SnApShOt('%s', %dL)",
+ tableIdent, firstSnapshot.snapshotId());
+
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())),
+ output);
+
+ assertEquals("Rollback must be successful",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testRollbackToInvalidSnapshot() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+
+ AssertHelpers.assertThrows("Should reject invalid snapshot id",
+ ValidationException.class, "Cannot roll back to unknown snapshot id",
+ () -> sql("CALL %s.system.rollback_to_snapshot('%s', -1L)", catalogName, tableIdent));
+ }
+
+ @Test
+ public void testInvalidRollbackToSnapshotCases() {
+ AssertHelpers.assertThrows("Should not allow mixed args",
+ AnalysisException.class, "Named and positional arguments cannot be mixed",
+ () -> sql("CALL %s.system.rollback_to_snapshot(namespace => 'n1', table => 't', 1L)", catalogName));
+
+ AssertHelpers.assertThrows("Should not resolve procedures in arbitrary namespaces",
+ NoSuchProcedureException.class, "not found",
+ () -> sql("CALL %s.custom.rollback_to_snapshot('n', 't', 1L)", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls without all required args",
+ AnalysisException.class, "Missing required parameters",
+ () -> sql("CALL %s.system.rollback_to_snapshot('t')", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls without all required args",
+ AnalysisException.class, "Missing required parameters",
+ () -> sql("CALL %s.system.rollback_to_snapshot(1L)", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls without all required args",
+ AnalysisException.class, "Missing required parameters",
+ () -> sql("CALL %s.system.rollback_to_snapshot(table => 't')", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls with invalid arg types",
+ AnalysisException.class, "Wrong arg type for snapshot_id: cannot cast",
+ () -> sql("CALL %s.system.rollback_to_snapshot('t', 2.2)", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls with empty table identifier",
+ IllegalArgumentException.class, "Cannot handle an empty identifier",
+ () -> sql("CALL %s.system.rollback_to_snapshot('', 1L)", catalogName));
+ }
+}
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToTimestampProcedure.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToTimestampProcedure.java
new file mode 100644
index 0000000..52fc12c
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRollbackToTimestampProcedure.java
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.time.LocalDateTime;
+import java.util.List;
+import java.util.Map;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.Snapshot;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.spark.sql.AnalysisException;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
+import org.junit.After;
+import org.junit.Assume;
+import org.junit.Test;
+
+public class TestRollbackToTimestampProcedure extends SparkExtensionsTestBase {
+
+ public TestRollbackToTimestampProcedure(String catalogName, String implementation, Map<String, String> config) {
+ super(catalogName, implementation, config);
+ }
+
+ @After
+ public void removeTables() {
+ sql("DROP TABLE IF EXISTS %s", tableName);
+ }
+
+ @Test
+ public void testRollbackToTimestampUsingPositionalArgs() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Snapshot firstSnapshot = table.currentSnapshot();
+ String firstSnapshotTimestamp = LocalDateTime.now().toString();
+
+ waitUntilAfter(firstSnapshot.timestampMillis());
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a"), row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ table.refresh();
+
+ Snapshot secondSnapshot = table.currentSnapshot();
+
+ List<Object[]> output = sql(
+ "CALL %s.system.rollback_to_timestamp('%s', TIMESTAMP '%s')",
+ catalogName, tableIdent, firstSnapshotTimestamp);
+
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())),
+ output);
+
+ assertEquals("Rollback must be successful",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testRollbackToTimestampUsingNamedArgs() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Snapshot firstSnapshot = table.currentSnapshot();
+ String firstSnapshotTimestamp = LocalDateTime.now().toString();
+
+ waitUntilAfter(firstSnapshot.timestampMillis());
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a"), row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ table.refresh();
+
+ Snapshot secondSnapshot = table.currentSnapshot();
+
+ List<Object[]> output = sql(
+ "CALL %s.system.rollback_to_timestamp(timestamp => TIMESTAMP '%s', table => '%s')",
+ catalogName, firstSnapshotTimestamp, tableIdent);
+
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())),
+ output);
+
+ assertEquals("Rollback must be successful",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testRollbackToTimestampRefreshesRelationCache() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Snapshot firstSnapshot = table.currentSnapshot();
+ String firstSnapshotTimestamp = LocalDateTime.now().toString();
+
+ waitUntilAfter(firstSnapshot.timestampMillis());
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ table.refresh();
+
+ Snapshot secondSnapshot = table.currentSnapshot();
+
+ Dataset<Row> query = spark.sql("SELECT * FROM " + tableName + " WHERE id = 1");
+ query.createOrReplaceTempView("tmp");
+
+ spark.sql("CACHE TABLE tmp");
+
+ assertEquals("View should have expected rows",
+ ImmutableList.of(row(1L, "a"), row(1L, "a")),
+ sql("SELECT * FROM tmp"));
+
+ List<Object[]> output = sql(
+ "CALL %s.system.rollback_to_timestamp(table => '%s', timestamp => TIMESTAMP '%s')",
+ catalogName, tableIdent, firstSnapshotTimestamp);
+
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())),
+ output);
+
+ assertEquals("View cache must be invalidated",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM tmp"));
+
+ sql("UNCACHE TABLE tmp");
+ }
+
+ @Test
+ public void testRollbackToTimestampWithQuotedIdentifiers() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Snapshot firstSnapshot = table.currentSnapshot();
+ String firstSnapshotTimestamp = LocalDateTime.now().toString();
+
+ waitUntilAfter(firstSnapshot.timestampMillis());
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a"), row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ table.refresh();
+
+ Snapshot secondSnapshot = table.currentSnapshot();
+
+ StringBuilder quotedNamespaceBuilder = new StringBuilder();
+ for (String level : tableIdent.namespace().levels()) {
+ quotedNamespaceBuilder.append("`");
+ quotedNamespaceBuilder.append(level);
+ quotedNamespaceBuilder.append("`");
+ }
+ String quotedNamespace = quotedNamespaceBuilder.toString();
+
+ List<Object[]> output = sql(
+ "CALL %s.system.rollback_to_timestamp('%s', TIMESTAMP '%s')",
+ catalogName, quotedNamespace + ".`" + tableIdent.name() + "`", firstSnapshotTimestamp);
+
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())),
+ output);
+
+ assertEquals("Rollback must be successful",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testRollbackToTimestampWithoutExplicitCatalog() {
+ Assume.assumeTrue("Working only with the session catalog", "spark_catalog".equals(catalogName));
+
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Snapshot firstSnapshot = table.currentSnapshot();
+ String firstSnapshotTimestamp = LocalDateTime.now().toString();
+
+ waitUntilAfter(firstSnapshot.timestampMillis());
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a"), row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ table.refresh();
+
+ Snapshot secondSnapshot = table.currentSnapshot();
+
+ // use camel case intentionally to test case insensitivity
+ List<Object[]> output = sql(
+ "CALL SyStEm.rOLlBaCk_to_TiMeStaMp('%s', TIMESTAMP '%s')",
+ tableIdent, firstSnapshotTimestamp);
+
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())),
+ output);
+
+ assertEquals("Rollback must be successful",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testInvalidRollbackToTimestampCases() {
+ String timestamp = "TIMESTAMP '2007-12-03T10:15:30'";
+
+ AssertHelpers.assertThrows("Should not allow mixed args",
+ AnalysisException.class, "Named and positional arguments cannot be mixed",
+ () -> sql("CALL %s.system.rollback_to_timestamp(namespace => 'n1', 't', %s)", catalogName, timestamp));
+
+ AssertHelpers.assertThrows("Should not resolve procedures in arbitrary namespaces",
+ NoSuchProcedureException.class, "not found",
+ () -> sql("CALL %s.custom.rollback_to_timestamp('n', 't', %s)", catalogName, timestamp));
+
+ AssertHelpers.assertThrows("Should reject calls without all required args",
+ AnalysisException.class, "Missing required parameters",
+ () -> sql("CALL %s.system.rollback_to_timestamp('t')", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls without all required args",
+ AnalysisException.class, "Missing required parameters",
+ () -> sql("CALL %s.system.rollback_to_timestamp(timestamp => %s)", catalogName, timestamp));
+
+ AssertHelpers.assertThrows("Should reject calls without all required args",
+ AnalysisException.class, "Missing required parameters",
+ () -> sql("CALL %s.system.rollback_to_timestamp(table => 't')", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls with extra args",
+ AnalysisException.class, "Too many arguments",
+ () -> sql("CALL %s.system.rollback_to_timestamp('n', 't', %s, 1L)", catalogName, timestamp));
+
+ AssertHelpers.assertThrows("Should reject calls with invalid arg types",
+ AnalysisException.class, "Wrong arg type for timestamp: cannot cast",
+ () -> sql("CALL %s.system.rollback_to_timestamp('t', 2.2)", catalogName));
+ }
+}
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java
new file mode 100644
index 0000000..0ea8c48
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java
@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.Snapshot;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.exceptions.ValidationException;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
+import org.apache.spark.sql.AnalysisException;
+import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
+import org.junit.After;
+import org.junit.Assume;
+import org.junit.Test;
+
+import static org.apache.iceberg.TableProperties.WRITE_AUDIT_PUBLISH_ENABLED;
+
+public class TestSetCurrentSnapshotProcedure extends SparkExtensionsTestBase {
+
+ public TestSetCurrentSnapshotProcedure(String catalogName, String implementation, Map<String, String> config) {
+ super(catalogName, implementation, config);
+ }
+
+ @After
+ public void removeTables() {
+ sql("DROP TABLE IF EXISTS %s", tableName);
+ }
+
+ @Test
+ public void testSetCurrentSnapshotUsingPositionalArgs() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Snapshot firstSnapshot = table.currentSnapshot();
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a"), row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ table.refresh();
+
+ Snapshot secondSnapshot = table.currentSnapshot();
+
+ List<Object[]> output = sql(
+ "CALL %s.system.set_current_snapshot('%s', %dL)",
+ catalogName, tableIdent, firstSnapshot.snapshotId());
+
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())),
+ output);
+
+ assertEquals("Set must be successful",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testSetCurrentSnapshotUsingNamedArgs() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Snapshot firstSnapshot = table.currentSnapshot();
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a"), row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ table.refresh();
+
+ Snapshot secondSnapshot = table.currentSnapshot();
+
+ List<Object[]> output = sql(
+ "CALL %s.system.set_current_snapshot(snapshot_id => %dL, table => '%s')",
+ catalogName, firstSnapshot.snapshotId(), tableIdent);
+
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())),
+ output);
+
+ assertEquals("Set must be successful",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testSetCurrentSnapshotWap() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' 'true')", tableName, WRITE_AUDIT_PUBLISH_ENABLED);
+
+ spark.conf().set("spark.wap.id", "1");
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should not see rows from staged snapshot",
+ ImmutableList.of(),
+ sql("SELECT * FROM %s", tableName));
+
+ Table table = validationCatalog.loadTable(tableIdent);
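+ // the WAP insert produced a single staged snapshot that is not yet the table's current snapshot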
+ Snapshot wapSnapshot = Iterables.getOnlyElement(table.snapshots());
+
+ List<Object[]> output = sql(
+ "CALL %s.system.set_current_snapshot(table => '%s', snapshot_id => %dL)",
+ catalogName, tableIdent, wapSnapshot.snapshotId());
+
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(null, wapSnapshot.snapshotId())),
+ output);
+
+ assertEquals("Current snapshot must be set correctly",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM %s", tableName));
+ }
+
+ @Test
+ public void testSetCurrentSnapshotWithoutExplicitCatalog() {
+ Assume.assumeTrue("Working only with the session catalog", "spark_catalog".equals(catalogName));
+
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Snapshot firstSnapshot = table.currentSnapshot();
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a"), row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ table.refresh();
+
+ Snapshot secondSnapshot = table.currentSnapshot();
+
+ // use camel case intentionally to test case insensitivity
+ List<Object[]> output = sql(
+ "CALL SyStEm.sEt_cuRrEnT_sNaPsHot('%s', %dL)",
+ tableIdent, firstSnapshot.snapshotId());
+
+ assertEquals("Procedure output must match",
+ ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())),
+ output);
+
+ assertEquals("Set must be successful",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testSetCurrentSnapshotToInvalidSnapshot() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
+
+ AssertHelpers.assertThrows("Should reject invalid snapshot id",
+ ValidationException.class, "Cannot roll back to unknown snapshot id",
+ () -> sql("CALL %s.system.set_current_snapshot('%s', -1L)", catalogName, tableIdent));
+ }
+
+ @Test
+ public void testInvalidSetCurrentSnapshotCases() {
+ AssertHelpers.assertThrows("Should not allow mixed args",
+ AnalysisException.class, "Named and positional arguments cannot be mixed",
+ () -> sql("CALL %s.system.set_current_snapshot(namespace => 'n1', table => 't', 1L)", catalogName));
+
+ AssertHelpers.assertThrows("Should not resolve procedures in arbitrary namespaces",
+ NoSuchProcedureException.class, "not found",
+ () -> sql("CALL %s.custom.set_current_snapshot('n', 't', 1L)", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls without all required args",
+ AnalysisException.class, "Missing required parameters",
+ () -> sql("CALL %s.system.set_current_snapshot('t')", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls without all required args",
+ AnalysisException.class, "Missing required parameters",
+ () -> sql("CALL %s.system.set_current_snapshot(1L)", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls without all required args",
+ AnalysisException.class, "Missing required parameters",
+ () -> sql("CALL %s.system.set_current_snapshot(snapshot_id => 1L)", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls without all required args",
+ AnalysisException.class, "Missing required parameters",
+ () -> sql("CALL %s.system.set_current_snapshot(table => 't')", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls with invalid arg types",
+ AnalysisException.class, "Wrong arg type for snapshot_id: cannot cast",
+ () -> sql("CALL %s.system.set_current_snapshot('t', 2.2)", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls with empty table identifier",
+ IllegalArgumentException.class, "Cannot handle an empty identifier",
+ () -> sql("CALL %s.system.set_current_snapshot('', 1L)", catalogName));
+ }
+}
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetWriteDistributionAndOrdering.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetWriteDistributionAndOrdering.java
new file mode 100644
index 0000000..473278d
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetWriteDistributionAndOrdering.java
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.util.Map;
+import org.apache.iceberg.NullOrder;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.TableProperties;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Test;
+
+import static org.apache.iceberg.expressions.Expressions.bucket;
+
+public class TestSetWriteDistributionAndOrdering extends SparkExtensionsTestBase {
+ public TestSetWriteDistributionAndOrdering(String catalogName, String implementation, Map<String, String> config) {
+ super(catalogName, implementation, config);
+ }
+
+ @After
+ public void removeTable() {
+ sql("DROP TABLE IF EXISTS %s", tableName);
+ }
+
+ @Test
+ public void testSetWriteOrderByColumn() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted());
+
+ sql("ALTER TABLE %s WRITE ORDERED BY category, id", tableName);
+
+ table.refresh();
+
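+ // a global write order implies a range distribution for writes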
+ String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE);
+ Assert.assertEquals("Distribution mode must match", "range", distributionMode);
+
+ SortOrder expected = SortOrder.builderFor(table.schema())
+ .withOrderId(1)
+ .asc("category", NullOrder.NULLS_FIRST)
+ .asc("id", NullOrder.NULLS_FIRST)
+ .build();
+ Assert.assertEquals("Should have expected order", expected, table.sortOrder());
+ }
+
+ @Test
+ public void testSetWriteOrderByColumnWithDirection() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted());
+
+ sql("ALTER TABLE %s WRITE ORDERED BY category ASC, id DESC", tableName);
+
+ table.refresh();
+
+ String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE);
+ Assert.assertEquals("Distribution mode must match", "range", distributionMode);
+
+ SortOrder expected = SortOrder.builderFor(table.schema())
+ .withOrderId(1)
+ .asc("category", NullOrder.NULLS_FIRST)
+ .desc("id", NullOrder.NULLS_LAST)
+ .build();
+ Assert.assertEquals("Should have expected order", expected, table.sortOrder());
+ }
+
+ @Test
+ public void testSetWriteOrderByColumnWithDirectionAndNullOrder() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted());
+
+ sql("ALTER TABLE %s WRITE ORDERED BY category ASC NULLS LAST, id DESC NULLS FIRST", tableName);
+
+ table.refresh();
+
+ String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE);
+ Assert.assertEquals("Distribution mode must match", "range", distributionMode);
+
+ SortOrder expected = SortOrder.builderFor(table.schema())
+ .withOrderId(1)
+ .asc("category", NullOrder.NULLS_LAST)
+ .desc("id", NullOrder.NULLS_FIRST)
+ .build();
+ Assert.assertEquals("Should have expected order", expected, table.sortOrder());
+ }
+
+ @Test
+ public void testSetWriteOrderByTransform() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted());
+
+ sql("ALTER TABLE %s WRITE ORDERED BY category DESC, bucket(16, id), id", tableName);
+
+ table.refresh();
+
+ String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE);
+ Assert.assertEquals("Distribution mode must match", "range", distributionMode);
+
+ SortOrder expected = SortOrder.builderFor(table.schema())
+ .withOrderId(1)
+ .desc("category")
+ .asc(bucket("id", 16))
+ .asc("id")
+ .build();
+ Assert.assertEquals("Should have expected order", expected, table.sortOrder());
+ }
+
+ @Test
+ public void testSetWriteUnordered() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted());
+
+ sql("ALTER TABLE %s WRITE ORDERED BY category DESC, bucket(16, id), id", tableName);
+
+ table.refresh();
+
+ String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE);
+ Assert.assertEquals("Distribution mode must match", "range", distributionMode);
+
+ Assert.assertNotEquals("Table must be sorted", SortOrder.unsorted(), table.sortOrder());
+
+ sql("ALTER TABLE %s WRITE UNORDERED", tableName);
+
+ table.refresh();
+
+ String newDistributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE);
+ Assert.assertEquals("New distribution mode must match", "none", newDistributionMode);
+
+ Assert.assertEquals("New sort order must match", SortOrder.unsorted(), table.sortOrder());
+ }
+
+ @Test
+ public void testSetWriteLocallyOrdered() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string, ts timestamp, data string) USING iceberg", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted());
+
+ sql("ALTER TABLE %s WRITE LOCALLY ORDERED BY category DESC, bucket(16, id), id", tableName);
+
+ table.refresh();
+
+ String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE);
+ Assert.assertEquals("Distribution mode must match", "none", distributionMode);
+
+ SortOrder expected = SortOrder.builderFor(table.schema())
+ .withOrderId(1)
+ .desc("category")
+ .asc(bucket("id", 16))
+ .asc("id")
+ .build();
+ Assert.assertEquals("Sort order must match", expected, table.sortOrder());
+ }
+
+ @Test
+ public void testSetWriteDistributedByWithSort() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted());
+
+ sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION ORDERED BY id", tableName);
+
+ table.refresh();
+
+ String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE);
+ Assert.assertEquals("Distribution mode must match", "hash", distributionMode);
+
+ SortOrder expected = SortOrder.builderFor(table.schema())
+ .withOrderId(1)
+ .asc("id")
+ .build();
+ Assert.assertEquals("Sort order must match", expected, table.sortOrder());
+ }
+
+ @Test
+ public void testSetWriteDistributedByWithLocalSort() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted());
+
+ sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION LOCALLY ORDERED BY id", tableName);
+
+ table.refresh();
+
+ String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE);
+ Assert.assertEquals("Distribution mode must match", "hash", distributionMode);
+
+ SortOrder expected = SortOrder.builderFor(table.schema())
+ .withOrderId(1)
+ .asc("id")
+ .build();
+ Assert.assertEquals("Sort order must match", expected, table.sortOrder());
+ }
+
+ @Test
+ public void testSetWriteDistributedByAndUnordered() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted());
+
+ sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION UNORDERED", tableName);
+
+ table.refresh();
+
+ String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE);
+ Assert.assertEquals("Distribution mode must match", "hash", distributionMode);
+
+ Assert.assertEquals("Sort order must match", SortOrder.unsorted(), table.sortOrder());
+ }
+
+ @Test
+ public void testSetWriteDistributedByOnly() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted());
+
+ sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION UNORDERED", tableName);
+
+ table.refresh();
+
+ String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE);
+ Assert.assertEquals("Distribution mode must match", "hash", distributionMode);
+
+ Assert.assertEquals("Sort order must match", SortOrder.unsorted(), table.sortOrder());
+ }
+
+ @Test
+ public void testSetWriteDistributedAndUnorderedInverted() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted());
+
+ sql("ALTER TABLE %s WRITE UNORDERED DISTRIBUTED BY PARTITION", tableName);
+
+ table.refresh();
+
+ String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE);
+ Assert.assertEquals("Distribution mode must match", "hash", distributionMode);
+
+ Assert.assertEquals("Sort order must match", SortOrder.unsorted(), table.sortOrder());
+ }
+
+ @Test
+ public void testSetWriteDistributedAndLocallyOrderedInverted() {
+ sql("CREATE TABLE %s (id bigint NOT NULL, category string) USING iceberg PARTITIONED BY (category)", tableName);
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertTrue("Table should start unsorted", table.sortOrder().isUnsorted());
+
+ sql("ALTER TABLE %s WRITE ORDERED BY id DISTRIBUTED BY PARTITION", tableName);
+
+ table.refresh();
+
+ String distributionMode = table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE);
+ Assert.assertEquals("Distribution mode must match", "hash", distributionMode);
+
+ SortOrder expected = SortOrder.builderFor(table.schema())
+ .withOrderId(1)
+ .asc("id")
+ .build();
+ Assert.assertEquals("Sort order must match", expected, table.sortOrder());
+ }
+}
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSnapshotTableProcedure.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSnapshotTableProcedure.java
new file mode 100644
index 0000000..66fa8e8
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSnapshotTableProcedure.java
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.io.IOException;
+import java.util.Map;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.spark.sql.AnalysisException;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Assume;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+public class TestSnapshotTableProcedure extends SparkExtensionsTestBase {
+ private static final String sourceName = "spark_catalog.default.source";
+ // Currently we can only snapshot tables in the Spark session catalog
+
+ public TestSnapshotTableProcedure(String catalogName, String implementation, Map<String, String> config) {
+ super(catalogName, implementation, config);
+ }
+
+ @Rule
+ public TemporaryFolder temp = new TemporaryFolder();
+
+ @After
+ public void removeTables() {
+ sql("DROP TABLE IF EXISTS %s", tableName);
+ sql("DROP TABLE IF EXISTS %S", sourceName);
+ }
+
+ @Test
+ public void testSnapshot() throws IOException {
+ String location = temp.newFolder().toString();
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", sourceName, location);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName);
+ Object result = scalarSql("CALL %s.system.snapshot('%s', '%s')", catalogName, sourceName, tableName);
+
+ Assert.assertEquals("Should have added one file", 1L, result);
+
+ Table createdTable = validationCatalog.loadTable(tableIdent);
+ String tableLocation = createdTable.location();
+ Assert.assertNotEquals("Table should not have the original location", location, tableLocation);
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a"), row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testSnapshotWithProperties() throws IOException {
+ String location = temp.newFolder().toString();
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", sourceName, location);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName);
+ Object result = scalarSql(
+ "CALL %s.system.snapshot(source_table => '%s', table => '%s', properties => map('foo','bar'))",
+ catalogName, sourceName, tableName);
+
+ Assert.assertEquals("Should have added one file", 1L, result);
+
+ Table createdTable = validationCatalog.loadTable(tableIdent);
+
+ String tableLocation = createdTable.location();
+ Assert.assertNotEquals("Table should not have the original location", location, tableLocation);
+
+ Map<String, String> props = createdTable.properties();
+ Assert.assertEquals("Should have extra property set", "bar", props.get("foo"));
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a"), row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testSnapshotWithAlternateLocation() throws IOException {
+ Assume.assumeTrue("No Snapshoting with Alternate locations with Hadoop Catalogs", !catalogName.contains("hadoop"));
+ String location = temp.newFolder().toString();
+ String snapshotLocation = temp.newFolder().toString();
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", sourceName, location);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName);
+ Object[] result = sql(
+ "CALL %s.system.snapshot(source_table => '%s', table => '%s', location => '%s')",
+ catalogName, sourceName, tableName, snapshotLocation).get(0);
+
+ Assert.assertEquals("Should have added one file", 1L, result[0]);
+
+ String storageLocation = validationCatalog.loadTable(tableIdent).location();
+ Assert.assertEquals("Snapshot should be made at specified location", snapshotLocation, storageLocation);
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a"), row(1L, "a")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testDropTable() throws IOException {
+ String location = temp.newFolder().toString();
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", sourceName, location);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName);
+
+ Object result = scalarSql("CALL %s.system.snapshot('%s', '%s')", catalogName, sourceName, tableName);
+ Assert.assertEquals("Should have added one file", 1L, result);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM %s", tableName));
+
+ sql("DROP TABLE %s", tableName);
+
+ assertEquals("Source table should be intact",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM %s", sourceName));
+ }
+
+ @Test
+ public void testSnapshotWithConflictingProps() throws IOException {
+ String location = temp.newFolder().toString();
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", sourceName, location);
+ sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName);
+
+ Object result = scalarSql(
+ "CALL %s.system.snapshot(" +
+ "source_table => '%s'," +
+ "table => '%s'," +
+ "properties => map('%s', 'true', 'snapshot', 'false'))",
+ catalogName, sourceName, tableName, TableProperties.GC_ENABLED);
+ Assert.assertEquals("Should have added one file", 1L, result);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1L, "a")),
+ sql("SELECT * FROM %s", tableName));
+
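+ // the procedure overrides the user-supplied values for the reserved 'snapshot' and 'gc.enabled' properties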
+ Table table = validationCatalog.loadTable(tableIdent);
+ Map<String, String> props = table.properties();
+ Assert.assertEquals("Should override user value", "true", props.get("snapshot"));
+ Assert.assertEquals("Should override user value", "false", props.get(TableProperties.GC_ENABLED));
+ }
+
+ @Test
+ public void testInvalidSnapshotsCases() throws IOException {
+ String location = temp.newFolder().toString();
+ sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", sourceName, location);
+
+ AssertHelpers.assertThrows("Should reject calls without all required args",
+ AnalysisException.class, "Missing required parameters",
+ () -> sql("CALL %s.system.snapshot('foo')", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls with invalid arg types",
+ AnalysisException.class, "Wrong arg type",
+ () -> sql("CALL %s.system.snapshot('n', 't', map('foo', 'bar'))", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls with invalid map args",
+ AnalysisException.class, "cannot resolve 'map",
+ () -> sql("CALL %s.system.snapshot('%s', 'fable', 'loc', map(2, 1, 1))", catalogName, sourceName));
+
+ AssertHelpers.assertThrows("Should reject calls with empty table identifier",
+ IllegalArgumentException.class, "Cannot handle an empty identifier",
+ () -> sql("CALL %s.system.snapshot('', 'dest')", catalogName));
+
+ AssertHelpers.assertThrows("Should reject calls with empty table identifier",
+ IllegalArgumentException.class, "Cannot handle an empty identifier",
+ () -> sql("CALL %s.system.snapshot('src', '')", catalogName));
+ }
+}
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java
new file mode 100644
index 0000000..2f2a85e
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java
@@ -0,0 +1,899 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.Snapshot;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.exceptions.ValidationException;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
+import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors;
+import org.apache.iceberg.spark.SparkSQLProperties;
+import org.apache.spark.SparkException;
+import org.apache.spark.sql.AnalysisException;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.hamcrest.CoreMatchers;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Assume;
+import org.junit.BeforeClass;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES;
+import static org.apache.iceberg.TableProperties.SPLIT_SIZE;
+import static org.apache.iceberg.TableProperties.UPDATE_ISOLATION_LEVEL;
+import static org.apache.spark.sql.functions.lit;
+
+public abstract class TestUpdate extends SparkRowLevelOperationsTestBase {
+
+ public TestUpdate(String catalogName, String implementation, Map<String, String> config,
+ String fileFormat, boolean vectorized, String distributionMode) {
+ super(catalogName, implementation, config, fileFormat, vectorized, distributionMode);
+ }
+
+ @BeforeClass
+ public static void setupSparkConf() {
+ spark.conf().set("spark.sql.shuffle.partitions", "4");
+ }
+
+ @After
+ public void removeTables() {
+ sql("DROP TABLE IF EXISTS %s", tableName);
+ sql("DROP TABLE IF EXISTS updated_id");
+ sql("DROP TABLE IF EXISTS updated_dep");
+ sql("DROP TABLE IF EXISTS deleted_employee");
+ }
+
+ @Test
+ public void testExplain() {
+ createAndInitTable("id INT, dep STRING");
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName);
+
+ sql("EXPLAIN UPDATE %s SET dep = 'invalid' WHERE id <=> 1", tableName);
+
+ sql("EXPLAIN UPDATE %s SET dep = 'invalid' WHERE true", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertEquals("Should have 1 snapshot", 1, Iterables.size(table.snapshots()));
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+ }
+
+ @Test
+ public void testUpdateEmptyTable() {
+ createAndInitTable("id INT, dep STRING");
+
+ sql("UPDATE %s SET dep = 'invalid' WHERE id IN (1)", tableName);
+ sql("UPDATE %s SET id = -1 WHERE dep = 'hr'", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots()));
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testUpdateWithAlias() {
+ createAndInitTable("id INT, dep STRING", "{ \"id\": 1, \"dep\": \"a\" }");
+ sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName);
+
+ sql("UPDATE %s AS t SET t.dep = 'invalid'", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots()));
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "invalid")),
+ sql("SELECT * FROM %s", tableName));
+ }
+
+ @Test
+ public void testUpdateAlignsAssignments() {
+ createAndInitTable("id INT, c1 INT, c2 INT");
+
+ sql("INSERT INTO TABLE %s VALUES (1, 11, 111), (2, 22, 222)", tableName);
+
+ sql("UPDATE %s SET `c2` = c2 - 2, c1 = `c1` - 1 WHERE id <=> 1", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, 10, 109), row(2, 22, 222)),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testUpdateWithUnsupportedPartitionPredicate() {
+ createAndInitTable("id INT, dep STRING");
+ sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName);
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'software'), (2, 'hr')", tableName);
+
+ sql("UPDATE %s t SET `t`.`id` = -1 WHERE t.dep LIKE '%%r' ", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "hr"), row(1, "software")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Test
+ public void testUpdateWithDynamicFileFiltering() {
+ createAndInitTable("id INT, dep STRING");
+ sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName);
+
+ append(tableName,
+ "{ \"id\": 1, \"dep\": \"hr\" }\n" +
+ "{ \"id\": 3, \"dep\": \"hr\" }");
+ append(tableName,
+ "{ \"id\": 1, \"dep\": \"hardware\" }\n" +
+ "{ \"id\": 2, \"dep\": \"hardware\" }");
+
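+ // dynamic file filtering should rewrite only the file that contains id = 2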
+ sql("UPDATE %s SET id = cast('-1' AS INT) WHERE id = 2", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots()));
+
+ Snapshot currentSnapshot = table.currentSnapshot();
+ validateSnapshot(currentSnapshot, "overwrite", "1", "1", "1");
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "hardware"), row(1, "hardware"), row(1, "hr"), row(3, "hr")),
+ sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+ }
+
+ @Test
+ public void testUpdateNonExistingRecords() {
+ createAndInitTable("id INT, dep STRING");
+ sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName);
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName);
+
+ sql("UPDATE %s SET id = -1 WHERE id > 10", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots()));
+
+ Snapshot currentSnapshot = table.currentSnapshot();
+ validateSnapshot(currentSnapshot, "overwrite", "0", null, null);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+ }
+
+ @Test
+ public void testUpdateWithoutCondition() {
+ createAndInitTable("id INT, dep STRING");
+ sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName);
+
+ sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION", tableName);
+
+ sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName);
+ sql("INSERT INTO TABLE %s VALUES (2, 'hardware')", tableName);
+ sql("INSERT INTO TABLE %s VALUES (null, 'hr')", tableName);
+
+ sql("UPDATE %s SET id = -1", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertEquals("Should have 4 snapshots", 4, Iterables.size(table.snapshots()));
+
+ Snapshot currentSnapshot = table.currentSnapshot();
+ validateSnapshot(currentSnapshot, "overwrite", "2", "3", "2");
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(-1, "hr")),
+ sql("SELECT * FROM %s ORDER BY dep ASC", tableName));
+ }
+
+ @Test
+ public void testUpdateWithNullConditions() {
+ createAndInitTable("id INT, dep STRING");
+
+ append(tableName,
+ "{ \"id\": 0, \"dep\": null }\n" +
+ "{ \"id\": 1, \"dep\": \"hr\" }\n" +
+ "{ \"id\": 2, \"dep\": \"hardware\" }");
+
+ // should not update any rows as null is never equal to null
+ sql("UPDATE %s SET id = -1 WHERE dep = NULL", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(0, null), row(1, "hr"), row(2, "hardware")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ // should not update any rows as the condition does not match any records
+ sql("UPDATE %s SET id = -1 WHERE dep = 'software'", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(0, null), row(1, "hr"), row(2, "hardware")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ // should update one matching row with a null-safe condition
+ sql("UPDATE %s SET dep = 'invalid', id = -1 WHERE dep <=> NULL", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "invalid"), row(1, "hr"), row(2, "hardware")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
+ @Ignore // TODO: fails due to SPARK-33267
+ public void testUpdateWithInAndNotInConditions() {
+ createAndInitTable("id INT, dep STRING");
+
+ append(tableName,
+ "{ \"id\": 1, \"dep\": \"hr\" }\n" +
+ "{ \"id\": 2, \"dep\": \"hardware\" }\n" +
+ "{ \"id\": null, \"dep\": \"hr\" }");
+
+ sql("UPDATE %s SET id = -1 WHERE id IN (1, null)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ sql("UPDATE %s SET id = 100 WHERE id NOT IN (null, 1)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ sql("UPDATE %s SET id = 100 WHERE id NOT IN (1, 10)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(100, "hardware"), row(100, "hr"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", tableName));
+ }
+
+ @Test
+ public void testUpdateWithMultipleRowGroupsParquet() throws NoSuchTableException {
+ Assume.assumeTrue(fileFormat.equalsIgnoreCase("parquet"));
+
+ createAndInitTable("id INT, dep STRING");
+ sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName);
+
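+ // use a tiny row group size and split size so the appended data spans multiple row groups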
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", tableName, PARQUET_ROW_GROUP_SIZE_BYTES, 100);
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", tableName, SPLIT_SIZE, 100);
+
+ List<Integer> ids = new ArrayList<>();
+ for (int id = 1; id <= 200; id++) {
+ ids.add(id);
+ }
+ Dataset<Row> df = spark.createDataset(ids, Encoders.INT())
+ .withColumnRenamed("value", "id")
+ .withColumn("dep", lit("hr"));
+ df.coalesce(1).writeTo(tableName).append();
+
+ Assert.assertEquals(200, spark.table(tableName).count());
+
+ // update a record from one of two row groups and copy over the second one
+ sql("UPDATE %s SET id = -1 WHERE id IN (200, 201)", tableName);
+
+ Assert.assertEquals(200, spark.table(tableName).count());
+ }
+
+ @Test
+ public void testUpdateNestedStructFields() {
+ createAndInitTable("id INT, s STRUCT<c1:INT,c2:STRUCT<a:ARRAY<INT>,m:MAP<STRING, STRING>>>",
+ "{ \"id\": 1, \"s\": { \"c1\": 2, \"c2\": { \"a\": [1,2], \"m\": { \"a\": \"b\"} } } } }");
+
+ // update primitive, array, map columns inside a struct
+ sql("UPDATE %s SET s.c1 = -1, s.c2.m = map('k', 'v'), s.c2.a = array(-1)", tableName);
+
+ assertEquals("Output should match",
+ ImmutableList.of(row(1, row(-1, row(ImmutableList.of(-1), ImmutableMap.of("k", "v"))))),
+ sql("SELECT * FROM %s", tableName));
+
+ // set primitive, array, map columns to NULL (proper casts should be in place)
+ sql("UPDATE %s SET s.c1 = NULL, s.c2 = NULL WHERE id IN (1)", tableName);
+
+ assertEquals("Output should match",
+ ImmutableList.of(row(1, row(null, null))),
+ sql("SELECT * FROM %s", tableName));
+
+ // update all fields in a struct
+ sql("UPDATE %s SET s = named_struct('c1', 1, 'c2', named_struct('a', array(1), 'm', null))", tableName);
+
+ assertEquals("Output should match",
+ ImmutableList.of(row(1, row(1, row(ImmutableList.of(1), null)))),
+ sql("SELECT * FROM %s", tableName));
+ }
+
+ @Test
+ public void testUpdateWithUserDefinedDistribution() {
+ createAndInitTable("id INT, c2 INT, c3 INT");
+ sql("ALTER TABLE %s ADD PARTITION FIELD bucket(8, c3)", tableName);
+
+ append(tableName,
+ "{ \"id\": 1, \"c2\": 11, \"c3\": 1 }\n" +
+ "{ \"id\": 2, \"c2\": 22, \"c3\": 1 }\n" +
+ "{ \"id\": 3, \"c2\": 33, \"c3\": 1 }");
+
+ // request a global sort
+ sql("ALTER TABLE %s WRITE ORDERED BY c2", tableName);
+ sql("UPDATE %s SET c2 = -22 WHERE id NOT IN (1, 3)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, 11, 1), row(2, -22, 1), row(3, 33, 1)),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ // request a local sort
+ sql("ALTER TABLE %s WRITE LOCALLY ORDERED BY id", tableName);
+ sql("UPDATE %s SET c2 = -33 WHERE id = 3", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, 11, 1), row(2, -22, 1), row(3, -33, 1)),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ // request a hash distribution + local sort
+ sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION ORDERED BY id", tableName);
+ sql("UPDATE %s SET c2 = -11 WHERE id = 1", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, -11, 1), row(2, -22, 1), row(3, -33, 1)),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+ }
+
+ @Test
+ public synchronized void testUpdateWithSerializableIsolation() throws InterruptedException {
+ // cannot run tests with concurrency for Hadoop tables without atomic renames
+ Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop"));
+
+ createAndInitTable("id INT, dep STRING");
+
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, UPDATE_ISOLATION_LEVEL, "serializable");
+
+ ExecutorService executorService = MoreExecutors.getExitingExecutorService(
+ (ThreadPoolExecutor) Executors.newFixedThreadPool(2));
+
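+ // shared counter used as a barrier so the update and append threads alternate, one commit per iteration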
+ AtomicInteger barrier = new AtomicInteger(0);
+
+ // update thread
+ Future<?> updateFuture = executorService.submit(() -> {
+ for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) {
+ while (barrier.get() < numOperations * 2) {
+ sleep(10);
+ }
+ sql("UPDATE %s SET id = -1 WHERE id = 1", tableName);
+ barrier.incrementAndGet();
+ }
+ });
+
+ // append thread
+ Future<?> appendFuture = executorService.submit(() -> {
+ for (int numOperations = 0; numOperations < Integer.MAX_VALUE; numOperations++) {
+ while (barrier.get() < numOperations * 2) {
+ sleep(10);
+ }
+ sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName);
+ barrier.incrementAndGet();
+ }
+ });
+
+ try {
+ updateFuture.get();
+ Assert.fail("Expected a validation exception");
+ } catch (ExecutionException e) {
+ Throwable sparkException = e.getCause();
+ Assert.assertThat(sparkException, CoreMatchers.instanceOf(SparkException.class));
+ Throwable validationException = sparkException.getCause();
+ Assert.assertThat(validationException, CoreMatchers.instanceOf(ValidationException.class));
+ String errMsg = validationException.getMessage();
+ Assert.assertThat(errMsg, CoreMatchers.containsString("Found conflicting files that can contain"));
+ } finally {
+ appendFuture.cancel(true);
+ }
+
+ executorService.shutdown();
+ Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES));
+ }
+
+ @Test
+ public synchronized void testUpdateWithSnapshotIsolation() throws InterruptedException, ExecutionException {
+ // cannot run tests with concurrency for Hadoop tables without atomic renames
+ Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop"));
+
+ createAndInitTable("id INT, dep STRING");
+
+ sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, UPDATE_ISOLATION_LEVEL, "snapshot");
+
+ ExecutorService executorService = MoreExecutors.getExitingExecutorService(
+ (ThreadPoolExecutor) Executors.newFixedThreadPool(2));
+
+ AtomicInteger barrier = new AtomicInteger(0);
+
+ // update thread
+ Future<?> updateFuture = executorService.submit(() -> {
+ for (int numOperations = 0; numOperations < 20; numOperations++) {
+ while (barrier.get() < numOperations * 2) {
+ sleep(10);
+ }
+ sql("UPDATE %s SET id = -1 WHERE id = 1", tableName);
+ barrier.incrementAndGet();
+ }
+ });
+
+ // append thread
+ Future<?> appendFuture = executorService.submit(() -> {
+ for (int numOperations = 0; numOperations < 20; numOperations++) {
+ while (barrier.get() < numOperations * 2) {
+ sleep(10);
+ }
+ sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName);
+ barrier.incrementAndGet();
+ }
+ });
+
+ try {
+ updateFuture.get();
+ } finally {
+ appendFuture.cancel(true);
+ }
+
+ executorService.shutdown();
+ Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES));
+ }
+
+ @Test
+ public void testUpdateWithInferredCasts() {
+ createAndInitTable("id INT, s STRING", "{ \"id\": 1, \"s\": \"value\" }");
+
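+ // assigning an integer literal to a string column relies on an implicitly inferred cast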
+ sql("UPDATE %s SET s = -1 WHERE id = 1", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "-1")),
+ sql("SELECT * FROM %s", tableName));
+ }
+
+ @Test
+ public void testUpdateModifiesNullStruct() {
+ createAndInitTable("id INT, s STRUCT<n1:INT,n2:INT>", "{ \"id\": 1, \"s\": null }");
+
+ sql("UPDATE %s SET s.n1 = -1 WHERE id = 1", tableName);
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, row(-1, null))),
+ sql("SELECT * FROM %s", tableName));
+ }
+
+ @Test
+ public void testUpdateRefreshesRelationCache() {
+ createAndInitTable("id INT, dep STRING");
+ sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName);
+
+ append(tableName,
+ "{ \"id\": 1, \"dep\": \"hr\" }\n" +
+ "{ \"id\": 3, \"dep\": \"hr\" }");
+
+ append(tableName,
+ "{ \"id\": 1, \"dep\": \"hardware\" }\n" +
+ "{ \"id\": 2, \"dep\": \"hardware\" }");
+
+ Dataset<Row> query = spark.sql("SELECT * FROM " + tableName + " WHERE id = 1");
+ query.createOrReplaceTempView("tmp");
+
+ spark.sql("CACHE TABLE tmp");
+
+ assertEquals("View should have correct data",
+ ImmutableList.of(row(1, "hardware"), row(1, "hr")),
+ sql("SELECT * FROM tmp ORDER BY id, dep"));
+
+ sql("UPDATE %s SET id = -1 WHERE id = 1", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots()));
+
+ Snapshot currentSnapshot = table.currentSnapshot();
+ validateSnapshot(currentSnapshot, "overwrite", "2", "2", "2");
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(2, "hardware"), row(3, "hr")),
+ sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+
+ assertEquals("Should refresh the relation cache",
+ ImmutableList.of(),
+ sql("SELECT * FROM tmp ORDER BY id, dep"));
+
+ spark.sql("UNCACHE TABLE tmp");
+ }
+
+ @Test
+ public void testUpdateWithInSubquery() {
+ createAndInitTable("id INT, dep STRING");
+
+ append(tableName,
+ "{ \"id\": 1, \"dep\": \"hr\" }\n" +
+ "{ \"id\": 2, \"dep\": \"hardware\" }\n" +
+ "{ \"id\": null, \"dep\": \"hr\" }");
+
+ createOrReplaceView("updated_id", Arrays.asList(0, 1, null), Encoders.INT());
+ createOrReplaceView("updated_dep", Arrays.asList("software", "hr"), Encoders.STRING());
+
+ sql("UPDATE %s SET id = -1 WHERE " +
+ "id IN (SELECT * FROM updated_id) AND " +
+ "dep IN (SELECT * from updated_dep)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ sql("UPDATE %s SET id = 5 WHERE id IS NULL OR id IN (SELECT value + 1 FROM updated_id)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "hr"), row(5, "hardware"), row(5, "hr")),
+ sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+
+ append(tableName,
+ "{ \"id\": null, \"dep\": \"hr\" }\n" +
+ "{ \"id\": 2, \"dep\": \"hr\" }");
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "hr"), row(2, "hr"), row(5, "hardware"), row(5, "hr"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", tableName));
+
+ sql("UPDATE %s SET id = 10 WHERE id IN (SELECT value + 2 FROM updated_id) AND dep = 'hr'", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "hr"), row(5, "hardware"), row(5, "hr"), row(10, "hr"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", tableName));
+ }
+
+ @Test
+ public void testUpdateWithInSubqueryAndDynamicFileFiltering() {
+ createAndInitTable("id INT, dep STRING");
+ sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName);
+
+ sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION", tableName);
+
+ append(tableName,
+ "{ \"id\": 1, \"dep\": \"hr\" }\n" +
+ "{ \"id\": 3, \"dep\": \"hr\" }");
+ append(tableName,
+ "{ \"id\": 1, \"dep\": \"hardware\" }\n" +
+ "{ \"id\": 2, \"dep\": \"hardware\" }");
+
+ createOrReplaceView("updated_id", Arrays.asList(-1, 2), Encoders.INT());
+
+ sql("UPDATE %s SET id = -1 WHERE id IN (SELECT * FROM updated_id)", tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots()));
+
+ Snapshot currentSnapshot = table.currentSnapshot();
+ validateSnapshot(currentSnapshot, "overwrite", "1", "1", "1");
+
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "hardware"), row(1, "hardware"), row(1, "hr"), row(3, "hr")),
+ sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+ }
+
+ @Test
+ public void testUpdateWithSelfSubquery() {
+ createAndInitTable("id INT, dep STRING");
+
+ append(tableName,
+ "{ \"id\": 1, \"dep\": \"hr\" }\n" +
+ "{ \"id\": 2, \"dep\": \"hr\" }");
+
+ sql("UPDATE %s SET dep = 'x' WHERE id IN (SELECT id + 1 FROM %s)", tableName, tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "hr"), row(2, "x")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ sql("UPDATE %s SET dep = 'y' WHERE " +
+ "id = (SELECT count(*) FROM (SELECT DISTINCT id FROM %s) AS t)", tableName, tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "hr"), row(2, "y")),
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+
+ sql("UPDATE %s SET id = (SELECT id - 2 FROM %s WHERE id = 1)", tableName, tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "hr"), row(-1, "y")),
+ sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+ }
+
+ @Test
+ public void testUpdateWithMultiColumnInSubquery() {
+ createAndInitTable("id INT, dep STRING");
+
+ append(tableName,
+ "{ \"id\": 1, \"dep\": \"hr\" }\n" +
+ "{ \"id\": 2, \"dep\": \"hardware\" }\n" +
+ "{ \"id\": null, \"dep\": \"hr\" }");
+
+ List<Employee> deletedEmployees = Arrays.asList(new Employee(null, "hr"), new Employee(1, "hr"));
+ createOrReplaceView("deleted_employee", deletedEmployees, Encoders.bean(Employee.class));
+
+ sql("UPDATE %s SET dep = 'x', id = -1 WHERE (id, dep) IN (SELECT id, dep FROM deleted_employee)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "x"), row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+ }
+
+ @Ignore // TODO: not supported since the SPARK-25154 fix is not yet available
+ public void testUpdateWithNotInSubquery() {
+ createAndInitTable("id INT, dep STRING");
+
+ append(tableName,
+ "{ \"id\": 1, \"dep\": \"hr\" }\n" +
+ "{ \"id\": 2, \"dep\": \"hardware\" }\n" +
+ "{ \"id\": null, \"dep\": \"hr\" }");
+
+ createOrReplaceView("updated_id", Arrays.asList(-1, -2, null), Encoders.INT());
+ createOrReplaceView("updated_dep", Arrays.asList("software", "hr"), Encoders.STRING());
+
+ // the file filter subquery (nested loop left-anti join) returns 0 records
+ sql("UPDATE %s SET id = -1 WHERE id NOT IN (SELECT * FROM updated_id)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ sql("UPDATE %s SET id = -1 WHERE id NOT IN (SELECT * FROM updated_id WHERE value IS NOT NULL)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", tableName));
+
+ sql("UPDATE %s SET id = 5 WHERE id NOT IN (SELECT * FROM updated_id) OR dep IN ('software', 'hr')", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "hardware"), row(5, "hr"), row(5, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", tableName));
+ }
+
+ @Test
+ public void testUpdateWithNotInSubqueryNotSupported() {
+ createAndInitTable("id INT, dep STRING");
+
+ createOrReplaceView("updated_id", Arrays.asList(-1, -2, null), Encoders.INT());
+
+ AssertHelpers.assertThrows("Should complain about NOT IN subquery",
+ AnalysisException.class, "Null-aware predicate subqueries are not currently supported",
+ () -> sql("UPDATE %s SET id = -1 WHERE id NOT IN (SELECT * FROM updated_id)", tableName));
+ }
+
+ @Test
+ public void testUpdateWithExistSubquery() {
+ createAndInitTable("id INT, dep STRING");
+
+ append(tableName,
+ "{ \"id\": 1, \"dep\": \"hr\" }\n" +
+ "{ \"id\": 2, \"dep\": \"hardware\" }\n" +
+ "{ \"id\": null, \"dep\": \"hr\" }");
+
+ createOrReplaceView("updated_id", Arrays.asList(-1, -2, null), Encoders.INT());
+ createOrReplaceView("updated_dep", Arrays.asList("hr", null), Encoders.STRING());
+
+ sql("UPDATE %s t SET id = -1 WHERE EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ sql("UPDATE %s t SET dep = 'x', id = -1 WHERE " +
+ "EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value + 2)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-1, "x"), row(2, "hardware"), row(null, "hr")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+ sql("UPDATE %s t SET id = -2 WHERE " +
+ "EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value) OR " +
+ "t.id IS NULL", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-2, "hr"), row(-2, "x"), row(2, "hardware")),
+ sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+
+ sql("UPDATE %s t SET id = 1 WHERE " +
+ "EXISTS (SELECT 1 FROM updated_id ui WHERE t.id = ui.value) AND " +
+ "EXISTS (SELECT 1 FROM updated_dep ud WHERE t.dep = ud.value)", tableName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(-2, "x"), row(1, "hr"), row(2, "hardware")),
+ sql("SELECT * FROM %s ORDER BY id, dep", tableName));
... 50157 lines suppressed ...