Posted to commits@spark.apache.org by li...@apache.org on 2016/12/30 18:27:28 UTC

spark git commit: [SPARK-18123][SQL] Use db column names instead of RDD column ones during JDBC Writing

Repository: spark
Updated Branches:
  refs/heads/master 852782b83 -> b85e29437


[SPARK-18123][SQL] Use db column names instead of RDD column ones during JDBC Writing

## What changes were proposed in this pull request?

Apache Spark supports the following cases **by quoting RDD column names** when saving through JDBC.
- Allow reserved keywords as column names, e.g., 'order'.
- Allow mixed-case column names, e.g., `[a: int, A: int]`, as in the following example.

  ``` scala
  scala> val df = sql("select 1 a, 1 A")
  df: org.apache.spark.sql.DataFrame = [a: int, A: int]
  ...
  scala> df.write.mode("overwrite").format("jdbc").options(option).save()
  scala> df.write.mode("append").format("jdbc").options(option).save()
  ```

This PR uses **database column names** instead of RDD column names so that the following case is also supported.
Note that this case already succeeds with `MySQL`, but previously failed on `Postgres`/`Oracle`.

``` scala
val df1 = sql("select 1 a")
val df2 = sql("select 1 A")
...
df1.write.mode("overwrite").format("jdbc").options(option).save()
df2.write.mode("append").format("jdbc").options(option).save()
```
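
For reference, the essence of the fix is to resolve each RDD column name against the existing table's schema and to quote the table's spelling of the name in the generated INSERT statement, falling back to the RDD column names when the table schema cannot be read. Below is a minimal standalone sketch of that normalization step; `InsertStatementSketch`, `quote`, and `insertStatement` are illustrative names only, not Spark APIs; the real logic added by this PR is in `JdbcUtils.getInsertStatement` (see the diff below).

``` scala
// Standalone sketch (not the actual Spark API): resolve the DataFrame's column
// names against the target table's columns, preferring the table's spelling.
object InsertStatementSketch {

  // Hypothetical stand-in for JdbcDialect.quoteIdentifier.
  private def quote(name: String): String = "\"" + name + "\""

  def insertStatement(
      table: String,
      rddColumns: Seq[String],           // column order comes from the DataFrame
      tableColumns: Option[Seq[String]], // None if the table schema could not be read
      isCaseSensitive: Boolean): String = {
    val nameEquality: (String, String) => Boolean =
      if (isCaseSensitive) _ == _ else _.equalsIgnoreCase(_)
    val columns = tableColumns match {
      case None =>
        // No table schema available: fall back to quoting the RDD column names.
        rddColumns.map(quote).mkString(",")
      case Some(tableCols) =>
        // Keep the RDD's column order but use the table's column spelling.
        rddColumns.map { col =>
          val normalized = tableCols.find(nameEquality(_, col)).getOrElse(
            sys.error(s"""Column "$col" not found in table schema"""))
          quote(normalized)
        }.mkString(",")
    }
    val placeholders = rddColumns.map(_ => "?").mkString(",")
    s"INSERT INTO $table ($columns) VALUES ($placeholders)"
  }
}

// Example: the DataFrame column is "a" but the Postgres table was created with "A".
// InsertStatementSketch.insertStatement("t", Seq("a"), Some(Seq("A")), isCaseSensitive = false)
// returns: INSERT INTO t ("A") VALUES (?)
```
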
## How was this patch tested?

Passed the Jenkins tests with a new test case.

Author: Dongjoon Hyun <do...@apache.org>
Author: gatorsmile <ga...@gmail.com>

Closes #15664 from dongjoon-hyun/SPARK-18123.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b85e2943
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b85e2943
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b85e2943

Branch: refs/heads/master
Commit: b85e29437d570118f5980a1d6ba56c1f06a3dfd1
Parents: 852782b
Author: Dongjoon Hyun <do...@apache.org>
Authored: Fri Dec 30 10:27:14 2016 -0800
Committer: gatorsmile <ga...@gmail.com>
Committed: Fri Dec 30 10:27:14 2016 -0800

----------------------------------------------------------------------
 .../datasources/jdbc/JdbcRelationProvider.scala | 11 +--
 .../execution/datasources/jdbc/JdbcUtils.scala  | 74 +++++++++++++++-----
 .../apache/spark/sql/jdbc/JDBCWriteSuite.scala  | 35 +++++++--
 3 files changed, 95 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b85e2943/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index 74f397c..e39d936 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -57,6 +57,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
     val table = jdbcOptions.table
     val createTableOptions = jdbcOptions.createTableOptions
     val isTruncate = jdbcOptions.isTruncate
+    val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis
 
     val conn = JdbcUtils.createConnectionFactory(jdbcOptions)()
     try {
@@ -67,16 +68,18 @@ class JdbcRelationProvider extends CreatableRelationProvider
             if (isTruncate && isCascadingTruncateTable(url) == Some(false)) {
               // In this case, we should truncate table and then load.
               truncateTable(conn, table)
-              saveTable(df, url, table, jdbcOptions)
+              val tableSchema = JdbcUtils.getSchemaOption(conn, url, table)
+              saveTable(df, url, table, tableSchema, isCaseSensitive, jdbcOptions)
             } else {
               // Otherwise, do not truncate the table, instead drop and recreate it
               dropTable(conn, table)
               createTable(df.schema, url, table, createTableOptions, conn)
-              saveTable(df, url, table, jdbcOptions)
+              saveTable(df, url, table, Some(df.schema), isCaseSensitive, jdbcOptions)
             }
 
           case SaveMode.Append =>
-            saveTable(df, url, table, jdbcOptions)
+            val tableSchema = JdbcUtils.getSchemaOption(conn, url, table)
+            saveTable(df, url, table, tableSchema, isCaseSensitive, jdbcOptions)
 
           case SaveMode.ErrorIfExists =>
             throw new AnalysisException(
@@ -89,7 +92,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
         }
       } else {
         createTable(df.schema, url, table, createTableOptions, conn)
-        saveTable(df, url, table, jdbcOptions)
+        saveTable(df, url, table, Some(df.schema), isCaseSensitive, jdbcOptions)
       }
     } finally {
       conn.close()

http://git-wip-us.apache.org/repos/asf/spark/blob/b85e2943/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index ff29a15..b138494 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal
 import org.apache.spark.TaskContext
 import org.apache.spark.executor.InputMetrics
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
@@ -108,14 +108,36 @@ object JdbcUtils extends Logging {
   }
 
   /**
-   * Returns a PreparedStatement that inserts a row into table via conn.
+   * Returns an Insert SQL statement for inserting a row into the target table via JDBC conn.
    */
-  def insertStatement(conn: Connection, table: String, rddSchema: StructType, dialect: JdbcDialect)
-      : PreparedStatement = {
-    val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
+  def getInsertStatement(
+      table: String,
+      rddSchema: StructType,
+      tableSchema: Option[StructType],
+      isCaseSensitive: Boolean,
+      dialect: JdbcDialect): String = {
+    val columns = if (tableSchema.isEmpty) {
+      rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
+    } else {
+      val columnNameEquality = if (isCaseSensitive) {
+        org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
+      } else {
+        org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
+      }
+      // The generated insert statement needs to follow rddSchema's column sequence and
+      // tableSchema's column names. When appending data into some case-sensitive DBMSs like
+      // PostgreSQL/Oracle, we need to respect the existing case-sensitive column names instead of
+      // RDD column names for user convenience.
+      val tableColumnNames = tableSchema.get.fieldNames
+      rddSchema.fields.map { col =>
+        val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
+          throw new AnalysisException(s"""Column "${col.name}" not found in schema $tableSchema""")
+        }
+        dialect.quoteIdentifier(normalizedName)
+      }.mkString(",")
+    }
     val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
-    val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders)"
-    conn.prepareStatement(sql)
+    s"INSERT INTO $table ($columns) VALUES ($placeholders)"
   }
 
   /**
@@ -211,6 +233,26 @@ object JdbcUtils extends Logging {
   }
 
   /**
+   * Returns the schema if the table already exists in the JDBC database.
+   */
+  def getSchemaOption(conn: Connection, url: String, table: String): Option[StructType] = {
+    val dialect = JdbcDialects.get(url)
+
+    try {
+      val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
+      try {
+        Some(getSchema(statement.executeQuery(), dialect))
+      } catch {
+        case _: SQLException => None
+      } finally {
+        statement.close()
+      }
+    } catch {
+      case _: SQLException => None
+    }
+  }
+
+  /**
    * Takes a [[ResultSet]] and returns its Catalyst schema.
    *
    * @return A [[StructType]] giving the Catalyst schema.
@@ -531,7 +573,7 @@ object JdbcUtils extends Logging {
       table: String,
       iterator: Iterator[Row],
       rddSchema: StructType,
-      nullTypes: Array[Int],
+      insertStmt: String,
       batchSize: Int,
       dialect: JdbcDialect,
       isolationLevel: Int): Iterator[Byte] = {
@@ -568,9 +610,9 @@ object JdbcUtils extends Logging {
         conn.setAutoCommit(false) // Everything in the same db transaction.
         conn.setTransactionIsolation(finalIsolationLevel)
       }
-      val stmt = insertStatement(conn, table, rddSchema, dialect)
-      val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
-        .map(makeSetter(conn, dialect, _)).toArray
+      val stmt = conn.prepareStatement(insertStmt)
+      val setters = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
+      val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
       val numFields = rddSchema.fields.length
 
       try {
@@ -657,16 +699,16 @@ object JdbcUtils extends Logging {
       df: DataFrame,
       url: String,
       table: String,
+      tableSchema: Option[StructType],
+      isCaseSensitive: Boolean,
       options: JDBCOptions): Unit = {
     val dialect = JdbcDialects.get(url)
-    val nullTypes: Array[Int] = df.schema.fields.map { field =>
-      getJdbcType(field.dataType, dialect).jdbcNullType
-    }
-
     val rddSchema = df.schema
     val getConnection: () => Connection = createConnectionFactory(options)
     val batchSize = options.batchSize
     val isolationLevel = options.isolationLevel
+
+    val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
     val repartitionedDF = options.numPartitions match {
       case Some(n) if n <= 0 => throw new IllegalArgumentException(
         s"Invalid value `$n` for parameter `${JDBCOptions.JDBC_NUM_PARTITIONS}` in table writing " +
@@ -675,7 +717,7 @@ object JdbcUtils extends Logging {
       case _ => df
     }
     repartitionedDF.foreachPartition(iterator => savePartition(
-      getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
+      getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel)
     )
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b85e2943/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index f49ac23..354af29 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -24,9 +24,9 @@ import scala.collection.JavaConverters.propertiesAsScalaMapConverter
 
 import org.scalatest.BeforeAndAfter
 
-import org.apache.spark.SparkException
-import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.{AnalysisException, Row, SaveMode}
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -96,6 +96,10 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
       StructField("id", IntegerType) ::
       StructField("seq", IntegerType) :: Nil)
 
+  private lazy val schema4 = StructType(
+      StructField("NAME", StringType) ::
+      StructField("ID", IntegerType) :: Nil)
+
   test("Basic CREATE") {
     val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
 
@@ -165,6 +169,26 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
     assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).collect()(0).length)
   }
 
+  test("SPARK-18123 Append with column names with different cases") {
+    val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+    val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema4)
+
+    df.write.jdbc(url, "TEST.APPENDTEST", new Properties())
+
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      val m = intercept[AnalysisException] {
+        df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties())
+      }.getMessage
+      assert(m.contains("Column \"NAME\" not found"))
+    }
+
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties())
+      assert(3 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).count())
+      assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).collect()(0).length)
+    }
+  }
+
   test("Truncate") {
     JdbcDialects.registerDialect(testH2Dialect)
     val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
@@ -177,7 +201,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
     assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count())
     assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
 
-    val m = intercept[SparkException] {
+    val m = intercept[AnalysisException] {
       df3.write.mode(SaveMode.Overwrite).option("truncate", true)
         .jdbc(url1, "TEST.TRUNCATETEST", properties)
     }.getMessage
@@ -203,9 +227,10 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
     val df2 = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3)
 
     df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties())
-    intercept[org.apache.spark.SparkException] {
+    val m = intercept[AnalysisException] {
       df2.write.mode(SaveMode.Append).jdbc(url, "TEST.INCOMPATIBLETEST", new Properties())
-    }
+    }.getMessage
+    assert(m.contains("Column \"seq\" not found"))
   }
 
   test("INSERT to JDBC Datasource") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org