Posted to commits@iceberg.apache.org by ja...@apache.org on 2023/02/02 23:08:44 UTC

[iceberg] branch master updated: Spark 3.3: REPLACE BRANCH SQL implementation (#6638)

This is an automated email from the ASF dual-hosted git repository.

jackye pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new 8b6c777369 Spark 3.3: REPLACE BRANCH SQL implementation (#6638)
8b6c777369 is described below

commit 8b6c777369e43af36d4a18a046989b0c9b354e39
Author: Amogh Jahagirdar <ja...@amazon.com>
AuthorDate: Thu Feb 2 15:08:36 2023 -0800

    Spark 3.3: REPLACE BRANCH SQL implementation (#6638)
    
    Co-authored-by: liliwei <hililiwei@gmail.com>
    Co-authored-by: xuwei <xuwei132@huawei.com>
    Co-authored-by: chidayong <chidayong2@h-partners.com>
---
 .../IcebergSqlExtensions.g4                        |  52 ++--
 .../IcebergSparkSqlExtensionsParser.scala          |   3 +-
 .../IcebergSqlExtensionsAstBuilder.scala           |  54 ++--
 .../{CreateBranch.scala => BranchOptions.scala}    |  15 +-
 ...ateBranch.scala => CreateOrReplaceBranch.scala} |   7 +-
 .../datasources/v2/CreateBranchExec.scala          |  70 ------
 .../datasources/v2/CreateOrReplaceBranchExec.scala |  81 ++++++
 .../v2/ExtendedDataSourceV2Strategy.scala          |   7 +-
 .../iceberg/spark/extensions/TestCreateBranch.java |  28 ++-
 .../spark/extensions/TestReplaceBranch.java        | 273 +++++++++++++++++++++
 10 files changed, 454 insertions(+), 136 deletions(-)
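
For reference, the SQL surface this commit adds, as a rough sketch (the table
and branch names are illustrative, and spark is assumed to be a SparkSession
with the Iceberg SQL extensions enabled):

    // Create a branch only if it does not already exist.
    spark.sql("ALTER TABLE db.tbl CREATE BRANCH IF NOT EXISTS b1")
    // Repoint an existing branch at a specific snapshot.
    spark.sql("ALTER TABLE db.tbl REPLACE BRANCH b1 AS OF VERSION 1234567890")
    // Create the branch if absent, otherwise replace it.
    spark.sql("ALTER TABLE db.tbl CREATE OR REPLACE BRANCH b1")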

diff --git a/spark/v3.3/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4 b/spark/v3.3/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4
index d8128d3905..d1ab06f852 100644
--- a/spark/v3.3/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4
+++ b/spark/v3.3/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4
@@ -73,13 +73,33 @@ statement
     | ALTER TABLE multipartIdentifier WRITE writeSpec                                       #setWriteDistributionAndOrdering
     | ALTER TABLE multipartIdentifier SET IDENTIFIER_KW FIELDS fieldList                    #setIdentifierFields
     | ALTER TABLE multipartIdentifier DROP IDENTIFIER_KW FIELDS fieldList                   #dropIdentifierFields
-    | ALTER TABLE multipartIdentifier CREATE BRANCH identifier (AS OF VERSION snapshotId)? (RETAIN snapshotRefRetain snapshotRefRetainTimeUnit)? (snapshotRetentionClause)?    #createBranch
+    | ALTER TABLE multipartIdentifier createReplaceBranchClause   #createOrReplaceBranch
     ;
 
-snapshotRetentionClause
-    : WITH SNAPSHOT RETENTION numSnapshots SNAPSHOTS
-    | WITH SNAPSHOT RETENTION snapshotRetain snapshotRetainTimeUnit
-    | WITH SNAPSHOT RETENTION numSnapshots SNAPSHOTS snapshotRetain snapshotRetainTimeUnit
+createReplaceBranchClause
+    : (CREATE OR)? REPLACE BRANCH identifier branchOptions
+    | CREATE BRANCH (IF NOT EXISTS)? identifier branchOptions
+    ;
+
+branchOptions
+    : (AS OF VERSION snapshotId)? (refRetain)? (snapshotRetention)?;
+
+snapshotRetention
+    : WITH SNAPSHOT RETENTION minSnapshotsToKeep
+    | WITH SNAPSHOT RETENTION maxSnapshotAge
+    | WITH SNAPSHOT RETENTION minSnapshotsToKeep maxSnapshotAge
+    ;
+
+refRetain
+    : RETAIN number timeUnit
+    ;
+
+maxSnapshotAge
+    : number timeUnit
+    ;
+
+minSnapshotsToKeep
+    : number SNAPSHOTS
     ;
 
 writeSpec
@@ -175,7 +195,7 @@ fieldList
     ;
 
 nonReserved
-    : ADD | ALTER | AS | ASC | BRANCH | BY | CALL | CREATE | DAYS | DESC | DROP | FIELD | FIRST | HOURS | LAST | NULLS | OF | ORDERED | PARTITION | TABLE | WRITE
+    : ADD | ALTER | AS | ASC | BRANCH | BY | CALL | CREATE | DAYS | DESC | DROP | EXISTS | FIELD | FIRST | HOURS | IF | LAST | NOT | NULLS | OF | OR | ORDERED | PARTITION | TABLE | WRITE
     | DISTRIBUTED | LOCALLY | MINUTES | MONTHS | UNORDERED | REPLACE | RETAIN | VERSION | WITH | IDENTIFIER_KW | FIELDS | SET | SNAPSHOT | SNAPSHOTS
     | TRUE | FALSE
     | MAP
@@ -189,22 +209,6 @@ numSnapshots
     : number
     ;
 
-snapshotRetain
-    : number
-    ;
-
-snapshotRefRetain
-    : number
-    ;
-
-snapshotRefRetainTimeUnit
-    : timeUnit
-    ;
-
-snapshotRetainTimeUnit
-    : timeUnit
-    ;
-
 timeUnit
     : DAYS
     | HOURS
@@ -222,17 +226,21 @@ DAYS: 'DAYS';
 DESC: 'DESC';
 DISTRIBUTED: 'DISTRIBUTED';
 DROP: 'DROP';
+EXISTS: 'EXISTS';
 FIELD: 'FIELD';
 FIELDS: 'FIELDS';
 FIRST: 'FIRST';
 HOURS: 'HOURS';
+IF : 'IF';
 LAST: 'LAST';
 LOCALLY: 'LOCALLY';
 MINUTES: 'MINUTES';
 MONTHS: 'MONTHS';
 CREATE: 'CREATE';
+NOT: 'NOT';
 NULLS: 'NULLS';
 OF: 'OF';
+OR: 'OR';
 ORDERED: 'ORDERED';
 PARTITION: 'PARTITION';
 REPLACE: 'REPLACE';
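
To make the new grammar rules concrete, here is how one fully-loaded statement
decomposes (a sketch; every literal value below is made up):

    // ALTER TABLE db.tbl CREATE OR REPLACE BRANCH b1  -> createReplaceBranchClause
    //   AS OF VERSION 123                             -> branchOptions (snapshotId)
    //   RETAIN 7 DAYS                                 -> refRetain (number timeUnit)
    //   WITH SNAPSHOT RETENTION 2 SNAPSHOTS 30 DAYS   -> snapshotRetention
    //                                (minSnapshotsToKeep followed by maxSnapshotAge)
    spark.sql(
      "ALTER TABLE db.tbl CREATE OR REPLACE BRANCH b1 AS OF VERSION 123 " +
        "RETAIN 7 DAYS WITH SNAPSHOT RETENTION 2 SNAPSHOTS 30 DAYS")
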
diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala
index 4c059f7c34..76af7d1ec6 100644
--- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala
+++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala
@@ -206,7 +206,8 @@ class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) extends ParserI
             normalized.contains("write unordered") ||
             normalized.contains("set identifier fields") ||
             normalized.contains("drop identifier fields") ||
-            normalized.contains("create branch")))
+            normalized.contains("create branch") ||
+            normalized.contains("replace branch")))
 
   }
 
diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala
index 950e161f9f..d6564d6ab9 100644
--- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala
+++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala
@@ -37,9 +37,10 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.parser.extensions.IcebergParserUtils.withOrigin
 import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser._
 import org.apache.spark.sql.catalyst.plans.logical.AddPartitionField
+import org.apache.spark.sql.catalyst.plans.logical.BranchOptions
 import org.apache.spark.sql.catalyst.plans.logical.CallArgument
 import org.apache.spark.sql.catalyst.plans.logical.CallStatement
-import org.apache.spark.sql.catalyst.plans.logical.CreateBranch
+import org.apache.spark.sql.catalyst.plans.logical.CreateOrReplaceBranch
 import org.apache.spark.sql.catalyst.plans.logical.DropIdentifierFields
 import org.apache.spark.sql.catalyst.plans.logical.DropPartitionField
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -91,25 +92,40 @@ class IcebergSqlExtensionsAstBuilder(delegate: ParserInterface) extends IcebergS
       typedVisit[Transform](ctx.transform))
   }
 
-  /**
-   * Create an ADD BRANCH logical command.
-   */
-  override def visitCreateBranch(ctx: CreateBranchContext): CreateBranch = withOrigin(ctx) {
-    val snapshotRetention = Option(ctx.snapshotRetentionClause())
-
-    CreateBranch(
+  override def visitCreateOrReplaceBranch(ctx: CreateOrReplaceBranchContext): CreateOrReplaceBranch = withOrigin(ctx) {
+    val createOrReplaceBranchClause = ctx.createReplaceBranchClause()
+
+    val branchName = createOrReplaceBranchClause.identifier()
+    val branchOptionsContext = Option(createOrReplaceBranchClause.branchOptions())
+    val snapshotId = branchOptionsContext.flatMap(branchOptions => Option(branchOptions.snapshotId()))
+      .map(_.getText.toLong)
+    val snapshotRetention = branchOptionsContext.flatMap(branchOptions => Option(branchOptions.snapshotRetention()))
+    val minSnapshotsToKeep = snapshotRetention.flatMap(retention => Option(retention.minSnapshotsToKeep()))
+      .map(minSnapshots => minSnapshots.number().getText.toLong)
+    val maxSnapshotAgeMs = snapshotRetention
+      .flatMap(retention => Option(retention.maxSnapshotAge()))
+      .map(retention => TimeUnit.valueOf(retention.timeUnit().getText.toUpperCase(Locale.ENGLISH))
+        .toMillis(retention.number().getText.toLong))
+    val branchRetention = branchOptionsContext.flatMap(branchOptions => Option(branchOptions.refRetain()))
+    val branchRefAgeMs = branchRetention.map(retain =>
+      TimeUnit.valueOf(retain.timeUnit().getText.toUpperCase(Locale.ENGLISH)).toMillis(retain.number().getText.toLong))
+    val replace = ctx.createReplaceBranchClause().REPLACE() != null
+    val ifNotExists = createOrReplaceBranchClause.EXISTS() != null
+
+    val branchOptions = BranchOptions(
+      snapshotId,
+      minSnapshotsToKeep,
+      maxSnapshotAgeMs,
+      branchRefAgeMs
+    )
+
+    CreateOrReplaceBranch(
       typedVisit[Seq[String]](ctx.multipartIdentifier),
-      ctx.identifier().getText,
-      Option(ctx.snapshotId()).map(_.getText.toLong),
-      snapshotRetention.flatMap(s => Option(s.numSnapshots())).map(_.getText.toLong),
-      snapshotRetention.flatMap(s => Option(s.snapshotRetain())).map(retain => {
-        TimeUnit.valueOf(ctx.snapshotRetentionClause().snapshotRetainTimeUnit().getText.toUpperCase(Locale.ENGLISH))
-          .toMillis(retain.getText.toLong)
-      }),
-      Option(ctx.snapshotRefRetain()).map(retain => {
-        TimeUnit.valueOf(ctx.snapshotRefRetainTimeUnit().getText.toUpperCase(Locale.ENGLISH))
-          .toMillis(retain.getText.toLong)
-      }))
+      branchName.getText,
+      branchOptions,
+      replace,
+      ifNotExists)
+
   }
 
   /**
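
The builder normalizes every duration to milliseconds before constructing the
plan node; the conversion is plain java.util.concurrent.TimeUnit, as in this
standalone sketch (values made up):

    import java.util.Locale
    import java.util.concurrent.TimeUnit

    // "RETAIN 7 DAYS" becomes branchRefAgeMs:
    val refAgeMs = TimeUnit.valueOf("days".toUpperCase(Locale.ENGLISH)).toMillis(7L)
    // refAgeMs == 604800000L
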
diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateBranch.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BranchOptions.scala
similarity index 62%
copy from spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateBranch.scala
copy to spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BranchOptions.scala
index 91e2bc6f19..4d7e0a086b 100644
--- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateBranch.scala
+++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BranchOptions.scala
@@ -19,16 +19,5 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
-import org.apache.spark.sql.catalyst.expressions.Attribute
-
-case class CreateBranch(table: Seq[String], branch: String, snapshotId: Option[Long], numSnapshots: Option[Long],
-                        snapshotRetain: Option[Long], snapshotRefRetain: Option[Long]) extends LeafCommand {
-
-  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
-
-  override lazy val output: Seq[Attribute] = Nil
-
-  override def simpleString(maxFields: Int): String = {
-    s"Create branch: ${branch} for table: ${table.quoted} "
-  }
-}
+case class BranchOptions (snapshotId: Option[Long], numSnapshots: Option[Long],
+                          snapshotRetain: Option[Long], snapshotRefRetain: Option[Long])
diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateBranch.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateOrReplaceBranch.scala
similarity index 79%
rename from spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateBranch.scala
rename to spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateOrReplaceBranch.scala
index 91e2bc6f19..24d6bd3d91 100644
--- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateBranch.scala
+++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/CreateOrReplaceBranch.scala
@@ -21,14 +21,15 @@ package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.expressions.Attribute
 
-case class CreateBranch(table: Seq[String], branch: String, snapshotId: Option[Long], numSnapshots: Option[Long],
-                        snapshotRetain: Option[Long], snapshotRefRetain: Option[Long]) extends LeafCommand {
+case class CreateOrReplaceBranch(table: Seq[String], branch: String,
+                                 branchOptions: BranchOptions, replace: Boolean, ifNotExists: Boolean)
+  extends LeafCommand {
 
   import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 
   override lazy val output: Seq[Attribute] = Nil
 
   override def simpleString(maxFields: Int): String = {
-    s"Create branch: ${branch} for table: ${table.quoted} "
+    s"CreateOrReplaceBranch branch: ${branch} for table: ${table.quoted}"
   }
 }
diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateBranchExec.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateBranchExec.scala
deleted file mode 100644
index acaab93b0b..0000000000
--- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateBranchExec.scala
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.v2
-
-import org.apache.iceberg.spark.source.SparkTable
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.CreateBranch
-import org.apache.spark.sql.connector.catalog.Identifier
-import org.apache.spark.sql.connector.catalog.TableCatalog
-
-case class CreateBranchExec(
-                             catalog: TableCatalog,
-                             ident: Identifier,
-                             createBranch: CreateBranch) extends LeafV2CommandExec {
-
-  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
-
-  override lazy val output: Seq[Attribute] = Nil
-
-  override protected def run(): Seq[InternalRow] = {
-    catalog.loadTable(ident) match {
-      case iceberg: SparkTable =>
-
-        val snapshotId = createBranch.snapshotId.getOrElse(iceberg.table.currentSnapshot().snapshotId())
-        val manageSnapshot = iceberg.table.manageSnapshots()
-          .createBranch(createBranch.branch, snapshotId)
-
-        if (createBranch.numSnapshots.nonEmpty) {
-          manageSnapshot.setMinSnapshotsToKeep(createBranch.branch, createBranch.numSnapshots.get.toInt)
-        }
-
-        if (createBranch.snapshotRetain.nonEmpty) {
-          manageSnapshot.setMaxSnapshotAgeMs(createBranch.branch, createBranch.snapshotRetain.get)
-        }
-
-        if (createBranch.snapshotRefRetain.nonEmpty) {
-          manageSnapshot.setMaxRefAgeMs(createBranch.branch, createBranch.snapshotRefRetain.get)
-        }
-
-        manageSnapshot.commit()
-
-      case table =>
-        throw new UnsupportedOperationException(s"Cannot add branch to non-Iceberg table: $table")
-    }
-
-    Nil
-  }
-
-  override def simpleString(maxFields: Int): String = {
-    s"Create branch: ${createBranch.branch} operation for table: ${ident.quoted}"
-  }
-}
diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateOrReplaceBranchExec.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateOrReplaceBranchExec.scala
new file mode 100644
index 0000000000..08230afb5a
--- /dev/null
+++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateOrReplaceBranchExec.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.iceberg.spark.source.SparkTable
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.BranchOptions
+import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.connector.catalog.TableCatalog
+
+case class CreateOrReplaceBranchExec(
+                              catalog: TableCatalog,
+                              ident: Identifier,
+                              branch: String,
+                              branchOptions: BranchOptions,
+                              replace: Boolean,
+                              ifNotExists: Boolean) extends LeafV2CommandExec {
+
+  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+  override lazy val output: Seq[Attribute] = Nil
+
+  override protected def run(): Seq[InternalRow] = {
+    catalog.loadTable(ident) match {
+      case iceberg: SparkTable =>
+        val snapshotId = branchOptions.snapshotId.getOrElse(iceberg.table.currentSnapshot().snapshotId())
+        val manageSnapshots = iceberg.table().manageSnapshots()
+        if (!replace) {
+          val ref = iceberg.table().refs().get(branch)
+          if (ref != null && ifNotExists) {
+            return Nil
+          }
+
+          manageSnapshots.createBranch(branch, snapshotId)
+        } else {
+          manageSnapshots.replaceBranch(branch, snapshotId)
+        }
+
+        if (branchOptions.numSnapshots.nonEmpty) {
+          manageSnapshots.setMinSnapshotsToKeep(branch, branchOptions.numSnapshots.get.toInt)
+        }
+
+        if (branchOptions.snapshotRetain.nonEmpty) {
+          manageSnapshots.setMaxSnapshotAgeMs(branch, branchOptions.snapshotRetain.get)
+        }
+
+        if (branchOptions.snapshotRefRetain.nonEmpty) {
+          manageSnapshots.setMaxRefAgeMs(branch, branchOptions.snapshotRefRetain.get)
+        }
+
+        manageSnapshots.commit()
+
+      case table =>
+        throw new UnsupportedOperationException(s"Cannot create or replace branch on non-Iceberg table: $table")
+    }
+
+    Nil
+  }
+
+  override def simpleString(maxFields: Int): String = {
+    s"CreateOrReplace branch: ${branch} for table: ${ident.quoted}"
+  }
+}
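
The exec node is a thin wrapper over Iceberg's ManageSnapshots API, so the same
effect can be had outside Spark. A minimal sketch, assuming table is an
already-loaded org.apache.iceberg.Table; the branch name and retention values
are made up:

    import java.util.concurrent.TimeUnit

    val snapshotId = table.currentSnapshot().snapshotId()
    table
      .manageSnapshots()
      .replaceBranch("b1", snapshotId)  // or createBranch("b1", snapshotId)
      .setMinSnapshotsToKeep("b1", 2)
      .setMaxSnapshotAgeMs("b1", TimeUnit.DAYS.toMillis(30))
      .setMaxRefAgeMs("b1", TimeUnit.DAYS.toMillis(7))
      .commit()
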
diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala
index 08c1c1dae6..7e343534de 100644
--- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala
+++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.expressions.PredicateHelper
 import org.apache.spark.sql.catalyst.plans.logical.AddPartitionField
 import org.apache.spark.sql.catalyst.plans.logical.Call
-import org.apache.spark.sql.catalyst.plans.logical.CreateBranch
+import org.apache.spark.sql.catalyst.plans.logical.CreateOrReplaceBranch
 import org.apache.spark.sql.catalyst.plans.logical.DeleteFromIcebergTable
 import org.apache.spark.sql.catalyst.plans.logical.DropIdentifierFields
 import org.apache.spark.sql.catalyst.plans.logical.DropPartitionField
@@ -62,8 +62,9 @@ case class ExtendedDataSourceV2Strategy(spark: SparkSession) extends Strategy wi
     case AddPartitionField(IcebergCatalogAndIdentifier(catalog, ident), transform, name) =>
       AddPartitionFieldExec(catalog, ident, transform, name) :: Nil
 
-    case CreateBranch(IcebergCatalogAndIdentifier(catalog, ident), _, _, _, _, _) =>
-      CreateBranchExec(catalog, ident, plan.asInstanceOf[CreateBranch]) :: Nil
+    case CreateOrReplaceBranch(
+        IcebergCatalogAndIdentifier(catalog, ident), branch, branchOptions, replace, ifNotExists) =>
+      CreateOrReplaceBranchExec(catalog, ident, branch, branchOptions, replace, ifNotExists) :: Nil
 
     case DropPartitionField(IcebergCatalogAndIdentifier(catalog, ident), transform) =>
       DropPartitionFieldExec(catalog, ident, transform) :: Nil
diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCreateBranch.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCreateBranch.java
index 0379bcf7a9..42d34779ee 100644
--- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCreateBranch.java
+++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCreateBranch.java
@@ -72,7 +72,7 @@ public class TestCreateBranch extends SparkExtensionsTestBase {
         tableName, branchName, snapshotId, maxRefAge, minSnapshotsToKeep, maxSnapshotAge);
     table.refresh();
     SnapshotRef ref = table.refs().get(branchName);
-    Assert.assertNotNull(ref);
+    Assert.assertEquals(table.currentSnapshot().snapshotId(), ref.snapshotId());
     Assert.assertEquals(minSnapshotsToKeep, ref.minSnapshotsToKeep());
     Assert.assertEquals(TimeUnit.DAYS.toMillis(maxSnapshotAge), ref.maxSnapshotAgeMs().longValue());
     Assert.assertEquals(TimeUnit.DAYS.toMillis(maxRefAge), ref.maxRefAgeMs().longValue());
@@ -91,7 +91,7 @@ public class TestCreateBranch extends SparkExtensionsTestBase {
     sql("ALTER TABLE %s CREATE BRANCH %s", tableName, branchName);
     table.refresh();
     SnapshotRef ref = table.refs().get(branchName);
-    Assert.assertNotNull(ref);
+    Assert.assertEquals(table.currentSnapshot().snapshotId(), ref.snapshotId());
     Assert.assertNull(ref.minSnapshotsToKeep());
     Assert.assertNull(ref.maxSnapshotAgeMs());
     Assert.assertNull(ref.maxRefAgeMs());
@@ -107,7 +107,7 @@ public class TestCreateBranch extends SparkExtensionsTestBase {
         tableName, branchName, minSnapshotsToKeep);
     table.refresh();
     SnapshotRef ref = table.refs().get(branchName);
-    Assert.assertNotNull(ref);
+    Assert.assertEquals(table.currentSnapshot().snapshotId(), ref.snapshotId());
     Assert.assertEquals(minSnapshotsToKeep, ref.minSnapshotsToKeep());
     Assert.assertNull(ref.maxSnapshotAgeMs());
     Assert.assertNull(ref.maxRefAgeMs());
@@ -129,6 +129,24 @@ public class TestCreateBranch extends SparkExtensionsTestBase {
     Assert.assertNull(ref.maxRefAgeMs());
   }
 
+  @Test
+  public void testCreateBranchIfNotExists() throws NoSuchTableException {
+    long maxSnapshotAge = 2L;
+    Table table = createDefaultTableAndInsert2Row();
+    String branchName = "b1";
+    sql(
+        "ALTER TABLE %s CREATE BRANCH %s WITH SNAPSHOT RETENTION %d DAYS",
+        tableName, branchName, maxSnapshotAge);
+    sql("ALTER TABLE %s CREATE BRANCH IF NOT EXISTS %s", tableName, branchName);
+
+    table.refresh();
+    SnapshotRef ref = table.refs().get(branchName);
+    Assert.assertEquals(table.currentSnapshot().snapshotId(), ref.snapshotId());
+    Assert.assertNull(ref.minSnapshotsToKeep());
+    Assert.assertEquals(TimeUnit.DAYS.toMillis(maxSnapshotAge), ref.maxSnapshotAgeMs().longValue());
+    Assert.assertNull(ref.maxRefAgeMs());
+  }
+
   @Test
   public void testCreateBranchUseCustomMinSnapshotsToKeepAndMaxSnapshotAge()
       throws NoSuchTableException {
@@ -141,7 +159,7 @@ public class TestCreateBranch extends SparkExtensionsTestBase {
         tableName, branchName, minSnapshotsToKeep, maxSnapshotAge);
     table.refresh();
     SnapshotRef ref = table.refs().get(branchName);
-    Assert.assertNotNull(ref);
+    Assert.assertEquals(table.currentSnapshot().snapshotId(), ref.snapshotId());
     Assert.assertEquals(minSnapshotsToKeep, ref.minSnapshotsToKeep());
     Assert.assertEquals(TimeUnit.DAYS.toMillis(maxSnapshotAge), ref.maxSnapshotAgeMs().longValue());
     Assert.assertNull(ref.maxRefAgeMs());
@@ -162,7 +180,7 @@ public class TestCreateBranch extends SparkExtensionsTestBase {
     sql("ALTER TABLE %s CREATE BRANCH %s RETAIN %d DAYS", tableName, branchName, maxRefAge);
     table.refresh();
     SnapshotRef ref = table.refs().get(branchName);
-    Assert.assertNotNull(ref);
+    Assert.assertEquals(table.currentSnapshot().snapshotId(), ref.snapshotId());
     Assert.assertNull(ref.minSnapshotsToKeep());
     Assert.assertNull(ref.maxSnapshotAgeMs());
     Assert.assertEquals(TimeUnit.DAYS.toMillis(maxRefAge), ref.maxRefAgeMs().longValue());
diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestReplaceBranch.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestReplaceBranch.java
new file mode 100644
index 0000000000..f97a95ff82
--- /dev/null
+++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestReplaceBranch.java
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.extensions;
+
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.SnapshotRef;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.spark.SparkCatalogConfig;
+import org.apache.iceberg.spark.source.SimpleRecord;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runners.Parameterized;
+
+public class TestReplaceBranch extends SparkExtensionsTestBase {
+
+  private static final String[] TIME_UNITS = {"DAYS", "HOURS", "MINUTES"};
+
+  @Parameterized.Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}")
+  public static Object[][] parameters() {
+    return new Object[][] {
+      {
+        SparkCatalogConfig.SPARK.catalogName(),
+        SparkCatalogConfig.SPARK.implementation(),
+        SparkCatalogConfig.SPARK.properties()
+      }
+    };
+  }
+
+  public TestReplaceBranch(String catalogName, String implementation, Map<String, String> config) {
+    super(catalogName, implementation, config);
+  }
+
+  @After
+  public void removeTable() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+  }
+
+  @Test
+  public void testReplaceBranchFailsForTag() throws NoSuchTableException {
+    sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName);
+    String tagName = "tag1";
+
+    List<SimpleRecord> records =
+        ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b"));
+    Dataset<Row> df = spark.createDataFrame(records, SimpleRecord.class);
+    df.writeTo(tableName).append();
+    Table table = validationCatalog.loadTable(tableIdent);
+    long first = table.currentSnapshot().snapshotId();
+    table.manageSnapshots().createTag(tagName, first).commit();
+    df.writeTo(tableName).append();
+    long second = table.currentSnapshot().snapshotId();
+
+    AssertHelpers.assertThrows(
+        "Cannot perform replace branch on tags",
+        IllegalArgumentException.class,
+        "Ref tag1 is a tag not a branch",
+        () -> sql("ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d", tableName, tagName, second));
+  }
+
+  @Test
+  public void testReplaceBranch() throws NoSuchTableException {
+    sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName);
+    List<SimpleRecord> records =
+        ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b"));
+    Dataset<Row> df = spark.createDataFrame(records, SimpleRecord.class);
+    df.writeTo(tableName).append();
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    long first = table.currentSnapshot().snapshotId();
+    String branchName = "b1";
+    long expectedMaxRefAgeMs = 1000;
+    int expectedMinSnapshotsToKeep = 2;
+    long expectedMaxSnapshotAgeMs = 1000;
+    table
+        .manageSnapshots()
+        .createBranch(branchName, first)
+        .setMaxRefAgeMs(branchName, expectedMaxRefAgeMs)
+        .setMinSnapshotsToKeep(branchName, expectedMinSnapshotsToKeep)
+        .setMaxSnapshotAgeMs(branchName, expectedMaxSnapshotAgeMs)
+        .commit();
+
+    df.writeTo(tableName).append();
+    long second = table.currentSnapshot().snapshotId();
+
+    sql("ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d", tableName, branchName, second);
+
+    table.refresh();
+    SnapshotRef ref = table.refs().get(branchName);
+    Assert.assertNotNull(ref);
+    Assert.assertEquals(ref.snapshotId(), second);
+    Assert.assertEquals(expectedMinSnapshotsToKeep, ref.minSnapshotsToKeep().intValue());
+    Assert.assertEquals(expectedMaxSnapshotAgeMs, ref.maxSnapshotAgeMs().longValue());
+    Assert.assertEquals(expectedMaxRefAgeMs, ref.maxRefAgeMs().longValue());
+  }
+
+  @Test
+  public void testReplaceBranchDoesNotExist() throws NoSuchTableException {
+    sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName);
+    List<SimpleRecord> records =
+        ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b"));
+    Dataset<Row> df = spark.createDataFrame(records, SimpleRecord.class);
+    df.writeTo(tableName).append();
+    Table table = validationCatalog.loadTable(tableIdent);
+
+    AssertHelpers.assertThrows(
+        "Cannot perform replace branch on branch which does not exist",
+        IllegalArgumentException.class,
+        "Branch does not exist",
+        () ->
+            sql(
+                "ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d",
+                tableName, "someBranch", table.currentSnapshot().snapshotId()));
+  }
+
+  @Test
+  public void testReplaceBranchWithRetain() throws NoSuchTableException {
+    sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName);
+    List<SimpleRecord> records =
+        ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b"));
+    Dataset<Row> df = spark.createDataFrame(records, SimpleRecord.class);
+    df.writeTo(tableName).append();
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    long first = table.currentSnapshot().snapshotId();
+    String branchName = "b1";
+    table.manageSnapshots().createBranch(branchName, first).commit();
+    SnapshotRef b1 = table.refs().get(branchName);
+    Integer minSnapshotsToKeep = b1.minSnapshotsToKeep();
+    Long maxSnapshotAgeMs = b1.maxSnapshotAgeMs();
+    df.writeTo(tableName).append();
+    long second = table.currentSnapshot().snapshotId();
+
+    long maxRefAge = 10;
+    for (String timeUnit : TIME_UNITS) {
+      sql(
+          "ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d RETAIN %d %s",
+          tableName, branchName, second, maxRefAge, timeUnit);
+
+      table.refresh();
+      SnapshotRef ref = table.refs().get(branchName);
+      Assert.assertNotNull(ref);
+      Assert.assertEquals(ref.snapshotId(), second);
+      Assert.assertEquals(minSnapshotsToKeep, ref.minSnapshotsToKeep());
+      Assert.assertEquals(maxSnapshotAgeMs, ref.maxSnapshotAgeMs());
+      Assert.assertEquals(
+          TimeUnit.valueOf(timeUnit).toMillis(maxRefAge), ref.maxRefAgeMs().longValue());
+    }
+  }
+
+  @Test
+  public void testReplaceBranchWithSnapshotRetention() throws NoSuchTableException {
+    sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName);
+    List<SimpleRecord> records =
+        ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b"));
+    Dataset<Row> df = spark.createDataFrame(records, SimpleRecord.class);
+    df.writeTo(tableName).append();
+    String branchName = "b1";
+    Table table = validationCatalog.loadTable(tableIdent);
+    long first = table.currentSnapshot().snapshotId();
+    table.manageSnapshots().createBranch(branchName, first).commit();
+    df.writeTo(tableName).append();
+    long second = table.currentSnapshot().snapshotId();
+
+    Integer minSnapshotsToKeep = 2;
+    long maxSnapshotAge = 2;
+    Long maxRefAgeMs = table.refs().get(branchName).maxRefAgeMs();
+    for (String timeUnit : TIME_UNITS) {
+      sql(
+          "ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d WITH SNAPSHOT RETENTION %d SNAPSHOTS %d %s",
+          tableName, branchName, second, minSnapshotsToKeep, maxSnapshotAge, timeUnit);
+
+      table.refresh();
+      SnapshotRef ref = table.refs().get(branchName);
+      Assert.assertNotNull(ref);
+      Assert.assertEquals(ref.snapshotId(), second);
+      Assert.assertEquals(minSnapshotsToKeep, ref.minSnapshotsToKeep());
+      Assert.assertEquals(
+          TimeUnit.valueOf(timeUnit).toMillis(maxSnapshotAge), ref.maxSnapshotAgeMs().longValue());
+      Assert.assertEquals(maxRefAgeMs, ref.maxRefAgeMs());
+    }
+  }
+
+  @Test
+  public void testReplaceBranchWithRetainAndSnapshotRetention() throws NoSuchTableException {
+    sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName);
+    List<SimpleRecord> records =
+        ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b"));
+    Dataset<Row> df = spark.createDataFrame(records, SimpleRecord.class);
+    df.writeTo(tableName).append();
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    long first = table.currentSnapshot().snapshotId();
+    String branchName = "b1";
+    table.manageSnapshots().createBranch(branchName, first).commit();
+    df.writeTo(tableName).append();
+    long second = table.currentSnapshot().snapshotId();
+
+    Integer minSnapshotsToKeep = 2;
+    long maxSnapshotAge = 2;
+    long maxRefAge = 10;
+    for (String timeUnit : TIME_UNITS) {
+      sql(
+          "ALTER TABLE %s REPLACE BRANCH %s AS OF VERSION %d RETAIN %d %s WITH SNAPSHOT RETENTION %d SNAPSHOTS %d %s",
+          tableName,
+          branchName,
+          second,
+          maxRefAge,
+          timeUnit,
+          minSnapshotsToKeep,
+          maxSnapshotAge,
+          timeUnit);
+
+      table.refresh();
+      SnapshotRef ref = table.refs().get(branchName);
+      Assert.assertNotNull(ref);
+      Assert.assertEquals(ref.snapshotId(), second);
+      Assert.assertEquals(minSnapshotsToKeep, ref.minSnapshotsToKeep());
+      Assert.assertEquals(
+          TimeUnit.valueOf(timeUnit).toMillis(maxSnapshotAge), ref.maxSnapshotAgeMs().longValue());
+      Assert.assertEquals(
+          TimeUnit.valueOf(timeUnit).toMillis(maxRefAge), ref.maxRefAgeMs().longValue());
+    }
+  }
+
+  @Test
+  public void testCreateOrReplace() throws NoSuchTableException {
+    sql("CREATE TABLE %s (id INT, data STRING) USING iceberg", tableName);
+    List<SimpleRecord> records =
+        ImmutableList.of(new SimpleRecord(1, "a"), new SimpleRecord(2, "b"));
+    Dataset<Row> df = spark.createDataFrame(records, SimpleRecord.class);
+    df.writeTo(tableName).append();
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    long first = table.currentSnapshot().snapshotId();
+    String branchName = "b1";
+    df.writeTo(tableName).append();
+    long second = table.currentSnapshot().snapshotId();
+    table.manageSnapshots().createBranch(branchName, second).commit();
+
+    sql(
+        "ALTER TABLE %s CREATE OR REPLACE BRANCH %s AS OF VERSION %d",
+        tableName, branchName, first);
+
+    table.refresh();
+    SnapshotRef ref = table.refs().get(branchName);
+    Assert.assertNotNull(ref);
+    Assert.assertEquals(ref.snapshotId(), first);
+  }
+}