Posted to commits@spark.apache.org by we...@apache.org on 2023/06/15 17:26:52 UTC

[spark] branch master updated: [SPARK-42750][SQL][FOLLOWUP] Add INSERT OVERWRITE BY NAME statement

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 49dc2dbcbb1 [SPARK-42750][SQL][FOLLOWUP] Add INSERT OVERWRITE BY NAME statement
49dc2dbcbb1 is described below

commit 49dc2dbcbb1d763d8ee0f2034e82347eebe50d0f
Author: Jia Fan <fa...@qq.com>
AuthorDate: Fri Jun 16 01:26:20 2023 +0800

    [SPARK-42750][SQL][FOLLOWUP] Add INSERT OVERWRITE BY NAME statement
    
    ### What changes were proposed in this pull request?
    This PR follows up #40908: after `INSERT BY NAME`, it adds the `INSERT OVERWRITE BY NAME` statement.
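    
    For illustration, a minimal sketch of the new syntax (the table and column names below are hypothetical, not taken from this patch):
    
        -- Columns in the source query are matched to the target table's columns
        -- by name rather than by position, so they may appear in any order.
        INSERT OVERWRITE TABLE target BY NAME SELECT 1 AS c2, 'a' AS c1;
        -- The TABLE keyword remains optional, as with INSERT INTO ... BY NAME:
        INSERT OVERWRITE target BY NAME SELECT 1 AS c2, 'a' AS c1;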
    
    ### Why are the changes needed?
    Adds the `INSERT OVERWRITE BY NAME` statement, for parity with the existing `INSERT INTO ... BY NAME` support.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added new tests in `DDLParserSuite` and `SQLInsertTestSuite`.
    
    Closes #41496 from Hisoka-X/SPARK-42750_Follow_up_insert_overwrite.
    
    Authored-by: Jia Fan <fa...@qq.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/parser/SqlBaseParser.g4     |   2 +-
 .../spark/sql/catalyst/parser/AstBuilder.scala     |   7 +-
 .../spark/sql/catalyst/parser/DDLParserSuite.scala |  21 +++--
 .../org/apache/spark/sql/SQLInsertTestSuite.scala  | 101 ++++++++++++++++++++-
 .../spark/sql/hive/HiveSQLInsertTestSuite.scala    |  11 +++
 5 files changed, 128 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index c7b238bfd2c..240310a426d 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -317,7 +317,7 @@ query
     ;
 
 insertInto
-    : INSERT OVERWRITE TABLE? identifierReference (partitionSpec (IF NOT EXISTS)?)?  identifierList?         #insertOverwriteTable
+    : INSERT OVERWRITE TABLE? identifierReference (partitionSpec (IF NOT EXISTS)?)?  ((BY NAME) | identifierList)? #insertOverwriteTable
     | INSERT INTO TABLE? identifierReference partitionSpec? (IF NOT EXISTS)? ((BY NAME) | identifierList)?   #insertIntoTable
     | INSERT INTO TABLE? identifierReference REPLACE whereClause                                             #insertIntoReplaceWhere
     | INSERT OVERWRITE LOCAL? DIRECTORY path=stringLit rowFormat? createFileFormat?                     #insertOverwriteHiveDir
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index a076385573e..abfe64f72e7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -320,7 +320,7 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
             byName)
         })
       case table: InsertOverwriteTableContext =>
-        val (relationCtx, cols, partition, ifPartitionNotExists, _)
+        val (relationCtx, cols, partition, ifPartitionNotExists, byName)
         = visitInsertOverwriteTable(table)
         withIdentClause(relationCtx, ident => {
           InsertIntoStatement(
@@ -329,7 +329,8 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
             cols,
             query,
             overwrite = true,
-            ifPartitionNotExists)
+            ifPartitionNotExists,
+            byName)
         })
       case ctx: InsertIntoReplaceWhereContext =>
         withIdentClause(ctx.identifierReference, ident => {
@@ -379,7 +380,7 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
         dynamicPartitionKeys.keys.mkString(", "), ctx)
     }
 
-    (ctx.identifierReference, cols, partitionKeys, ctx.EXISTS() != null, false)
+    (ctx.identifierReference, cols, partitionKeys, ctx.EXISTS() != null, ctx.NAME() != null)
   }
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
index f07de11727e..31fd232181a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
@@ -1690,17 +1690,22 @@ class DDLParserSuite extends AnalysisTest {
           Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
           overwrite = false, ifPartitionNotExists = false, byName = true))
     }
+
+    Seq(
+      "INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl BY NAME SELECT * FROM source",
+      "INSERT OVERWRITE testcat.ns1.ns2.tbl BY NAME SELECT * FROM source"
+    ).foreach { sql =>
+      parseCompare(sql,
+        InsertIntoStatement(
+          UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
+          Map.empty,
+          Nil,
+          Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
+          overwrite = true, ifPartitionNotExists = false, byName = true))
+    }
   }
 
   test("insert table: by name unsupported case") {
-    checkError(
-      exception = parseException("INSERT OVERWRITE TABLE t1 BY NAME SELECT * FROM tmp_view"),
-      errorClass = "PARSE_SYNTAX_ERROR",
-      parameters = Map(
-        "error" -> "'BY'",
-        "hint" -> "")
-    )
-
     checkError(
       exception = parseException(
         "INSERT INTO TABLE t1 BY NAME (c1,c2) SELECT * FROM tmp_view"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala
index 1d27904bb2c..bb3125de9c4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala
@@ -17,8 +17,7 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.SparkConf
-import org.apache.spark.SparkNumberFormatException
+import org.apache.spark.{SparkConf, SparkNumberFormatException, SparkThrowable}
 import org.apache.spark.sql.catalyst.expressions.Hex
 import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog
 import org.apache.spark.sql.internal.SQLConf
@@ -34,6 +33,13 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils {
 
   def format: String
 
+  def checkV1AndV2Error(
+      exception: SparkThrowable,
+      v1ErrorClass: String,
+      v2ErrorClass: String,
+      v1Parameters: Map[String, String],
+      v2Parameters: Map[String, String]): Unit
+
   protected def createTable(
       table: String,
       cols: Seq[String],
@@ -160,6 +166,74 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils {
     }
   }
 
+  test("insert overwrite with column list - by name") {
+    withTable("t1") {
+      val cols = Seq("c1", "c2", "c3")
+      val df = Seq((3, 2, 1)).toDF(cols.reverse: _*)
+      createTable("t1", cols, Seq("int", "int", "int"))
+      processInsert("t1", df, overwrite = false)
+      verifyTable("t1", df.selectExpr(cols.reverse: _*))
+      processInsert("t1", df, overwrite = true, byName = true)
+      verifyTable("t1", df.selectExpr(cols: _*))
+    }
+  }
+
+  test("insert overwrite with column list - by name + partitioned table") {
+    val cols = Seq("c1", "c2", "c3", "c4")
+    val df = Seq((4, 3, 2, 1)).toDF(cols.reverse: _*)
+    withTable("t1") {
+      createTable("t1", cols, Seq("int", "int", "int", "int"), cols.takeRight(2))
+      processInsert("t1", df.selectExpr("c2", "c1", "c4"),
+        partitionExprs = Seq("c3=3", "c4"), overwrite = false)
+      verifyTable("t1", df.selectExpr("c2", "c1", "c3", "c4"))
+      processInsert("t1", df.selectExpr("c2", "c1", "c4"),
+        partitionExprs = Seq("c3=3", "c4"), overwrite = true, byName = true)
+      verifyTable("t1", df.selectExpr(cols: _*))
+    }
+
+    withTable("t1") {
+      createTable("t1", cols, Seq("int", "int", "int", "int"), cols.takeRight(2))
+      processInsert("t1", df.selectExpr("c2", "c1"),
+        partitionExprs = Seq("c3=3", "c4=4"), overwrite = false)
+      verifyTable("t1", df.selectExpr("c2", "c1", "c3", "c4"))
+      processInsert("t1", df.selectExpr("c2", "c1"),
+        partitionExprs = Seq("c3=3", "c4=4"), overwrite = true, byName = true)
+      verifyTable("t1", df.selectExpr(cols: _*))
+    }
+  }
+
+  test("insert by name: mismatch column name") {
+    withTable("t1") {
+      val cols = Seq("c1", "c2", "c3")
+      val cols2 = Seq("x1", "c2", "c3")
+      val df = Seq((3, 2, 1)).toDF(cols2.reverse: _*)
+      createTable("t1", cols, Seq("int", "int", "int"))
+      checkV1AndV2Error(
+        exception = intercept[AnalysisException] {
+          processInsert("t1", df, overwrite = false, byName = true)
+        },
+        v1ErrorClass = "_LEGACY_ERROR_TEMP_1186",
+        v2ErrorClass = "_LEGACY_ERROR_TEMP_1204",
+        v1Parameters = Map.empty[String, String],
+        v2Parameters = Map("tableName" -> "testcat.t1",
+          "errors" -> "Cannot find data for output column 'c1'")
+      )
+      val df2 = Seq((3, 2, 1, 0)).toDF(Seq("c3", "c2", "c1", "c0"): _*)
+      checkV1AndV2Error(
+        exception = intercept[AnalysisException] {
+          processInsert("t1", df2, overwrite = false, byName = true)
+        },
+        v1ErrorClass = "INSERT_COLUMN_ARITY_MISMATCH",
+        v2ErrorClass = "INSERT_COLUMN_ARITY_MISMATCH",
+        v1Parameters = Map("tableName" -> "`spark_catalog`.`default`.`t1`",
+          "reason" -> "too many data columns", "tableColumns" -> "'c1', 'c2', 'c3'",
+          "dataColumns" -> "'c3', 'c2', 'c1', 'c0'"),
+        v2Parameters = Map("tableName" -> "testcat.t1", "reason" -> "too many data columns",
+          "tableColumns" -> "'c1', 'c2', 'c3'", "dataColumns" -> "'c3', 'c2', 'c1', 'c0'")
+      )
+    }
+  }
+
   test("insert with column list - table output reorder + partitioned table") {
     val cols = Seq("c1", "c2", "c3", "c4")
     val df = Seq((1, 2, 3, 4)).toDF(cols: _*)
@@ -440,16 +514,38 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils {
 }
 
 class FileSourceSQLInsertTestSuite extends SQLInsertTestSuite with SharedSparkSession {
+
   override def format: String = "parquet"
+
+  override def checkV1AndV2Error(
+      exception: SparkThrowable,
+      v1ErrorClass: String,
+      v2ErrorClass: String,
+      v1Parameters: Map[String, String],
+      v2Parameters: Map[String, String]): Unit = {
+    checkError(exception = exception, sqlState = None, errorClass = v1ErrorClass,
+      parameters = v1Parameters)
+  }
+
   override protected def sparkConf: SparkConf = {
     super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, format)
   }
+
 }
 
 class DSV2SQLInsertTestSuite extends SQLInsertTestSuite with SharedSparkSession {
 
   override def format: String = "foo"
 
+  override def checkV1AndV2Error(
+      exception: SparkThrowable,
+      v1ErrorClass: String,
+      v2ErrorClass: String,
+      v1Parameters: Map[String, String],
+      v2Parameters: Map[String, String]): Unit = {
+    checkError(exception = exception, sqlState = None, errorClass = v2ErrorClass,
+      parameters = v2Parameters)
+  }
   protected override def sparkConf: SparkConf = {
     super.sparkConf
       .set("spark.sql.catalog.testcat", classOf[InMemoryPartitionTableCatalog].getName)
@@ -467,4 +563,5 @@ class DSV2SQLInsertTestSuite extends SQLInsertTestSuite with SharedSparkSession
         parameters = Map("staticName" -> "c"))
     }
   }
+
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSQLInsertTestSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSQLInsertTestSuite.scala
index 0b1d511f085..d6ba38359f4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSQLInsertTestSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSQLInsertTestSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.hive
 
+import org.apache.spark.SparkThrowable
 import org.apache.spark.sql.SQLInsertTestSuite
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 
@@ -37,4 +38,14 @@ class HiveSQLInsertTestSuite extends SQLInsertTestSuite with TestHiveSingleton {
   }
 
   override def format: String = "hive OPTIONS(fileFormat='parquet')"
+
+  override def checkV1AndV2Error(
+      exception: SparkThrowable,
+      v1ErrorClass: String,
+      v2ErrorClass: String,
+      v1Parameters: Map[String, String],
+      v2Parameters: Map[String, String]): Unit = {
+    checkError(exception = exception, sqlState = None, errorClass = v1ErrorClass,
+      parameters = v1Parameters)
+  }
 }

