Posted to commits@spark.apache.org by gu...@apache.org on 2021/07/19 08:49:03 UTC

[spark] branch master updated: [SPARK-36163][SQL] Propagate correct JDBC properties in JDBC connector provider and add "connectionProvider" option

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4036ad9  [SPARK-36163][SQL] Propagate correct JDBC properties in JDBC connector provider and add "connectionProvider" option
4036ad9 is described below

commit 4036ad9ad9d2fd0e5e5fb8b9a86bf7f4e408b1b9
Author: Ivan Sadikov <iv...@databricks.com>
AuthorDate: Mon Jul 19 17:48:32 2021 +0900

    [SPARK-36163][SQL] Propagate correct JDBC properties in JDBC connector provider and add "connectionProvider" option
    
    ### What changes were proposed in this pull request?
    
    This PR fixes two issues highlighted in https://issues.apache.org/jira/browse/SPARK-36163:
    - JDBC connection provider propagates incorrect connection properties.
    - Ambiguity when more than one JDBC connection provider is available.
    
    I updated `BasicConnectionProvider` to use `jdbcOptions.asConnectionProperties`, which strips JDBC data source-specific options before the properties are handed to the driver.
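
    Conceptually, `asConnectionProperties` forwards only genuine driver properties and drops Spark-side options. A rough sketch of the intended behaviour (the option values are placeholders, and `JDBCOptions` is an internal class, so this is illustrative rather than public API usage):

    ```scala
    import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions

    val jdbcOptions = new JDBCOptions(Map(
      "url" -> "jdbc:postgresql://localhost/postgres",
      "dbtable" -> "my_table",
      "user" -> "admin"))

    // Data source options such as "url", "dbtable" and "connectionProvider"
    // are filtered out; plain driver properties (e.g. "user") should pass
    // through to Driver.connect().
    val props = jdbcOptions.asConnectionProperties
    assert(!props.containsKey("url") && !props.containsKey("dbtable"))
    ```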
    
    I also added a `connectionProvider` data source option that specifies the provider name, e.g. `db2` or `presto`, so users can enforce a specific provider in case of ambiguity.
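
    A minimal usage sketch (the URL and table name are illustrative placeholders):

    ```scala
    // Pin the "db2" provider for this read; without the option, resolution
    // fails when more than one provider can handle the driver and options.
    val df = spark.read
      .format("jdbc")
      .option("url", "jdbc:db2://localhost:50000/sampledb")
      .option("dbtable", "my_table")
      .option("connectionProvider", "db2")
      .load()
    ```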
    
    ### Why are the changes needed?
    Users can leverage `spark.sql.sources.disabledJdbcConnProviderList`, but it is cumbersome: it requires disabling all other providers, which becomes problematic when ambiguous providers are needed in two or more different JDBC queries.
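
    For reference, the existing workaround looks roughly like this (the disabled provider names are illustrative):

    ```scala
    import org.apache.spark.sql.SparkSession

    // Session-wide: every provider except the desired one must be disabled,
    // which affects all JDBC reads and writes in the session.
    val spark = SparkSession.builder()
      .config("spark.sql.sources.disabledJdbcConnProviderList", "basic,mssql,oracle")
      .getOrCreate()
    ```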
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes
    
    PROBLEM DESCRIPTION:
    This introduces a new JDBC data source option, `connectionProvider`, that allows users to select a specific JDBC connection provider by its short name. I updated the SQL guide doc and the README.
    
    Before this change, the only way to resolve ambiguity was the SQL conf that disables all of the other JDBC connection providers. After this change, users can specify the exact connection provider they need per data source.
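
    A sketch of the per-source selection this enables (URLs, tables, and provider pairings are placeholders):

    ```scala
    // Two reads in the same session, each pinned to its own provider;
    // previously this required juggling the session-wide disabled list.
    val ordersDf = spark.read.format("jdbc")
      .option("url", "jdbc:db2://db2-host:50000/orders")
      .option("dbtable", "orders")
      .option("connectionProvider", "db2")
      .load()

    val eventsDf = spark.read.format("jdbc")
      .option("url", "jdbc:presto://presto-host:8080/hive/default")
      .option("dbtable", "events")
      .option("connectionProvider", "presto")
      .load()
    ```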
    
    ### How was this patch tested?
    
    I updated the existing `ConnectionProviderSuite` and added a new `BasicConnectionProviderSuite`.
    
    Closes #33370 from sadikovi/fix-jdbc-conn-provider.
    
    Authored-by: Ivan Sadikov <iv...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 docs/sql-data-sources-jdbc.md                      | 19 ++++-
 .../execution/datasources/jdbc/JDBCOptions.scala   |  4 +
 .../sql/execution/datasources/jdbc/JdbcUtils.scala |  3 +-
 .../jdbc/connection/BasicConnectionProvider.scala  |  2 +-
 .../jdbc/connection/ConnectionProvider.scala       | 42 +++++++++--
 .../main/scala/org/apache/spark/sql/jdbc/README.md | 10 ++-
 .../connection/BasicConnectionProviderSuite.scala  | 57 ++++++++++++++
 .../jdbc/connection/ConnectionProviderSuite.scala  | 88 +++++++++++++++++++++-
 8 files changed, 209 insertions(+), 16 deletions(-)

diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md
index c973e8a..6d44a22 100644
--- a/docs/sql-data-sources-jdbc.md
+++ b/docs/sql-data-sources-jdbc.md
@@ -9,9 +9,9 @@ license: |
   The ASF licenses this file to You under the Apache License, Version 2.0
   (the "License"); you may not use this file except in compliance with
   the License.  You may obtain a copy of the License at
- 
+
      http://www.apache.org/licenses/LICENSE-2.0
- 
+
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -191,7 +191,7 @@ logging into the data sources.
     <td>write</td>
    </td>
   </tr>
-  
+
   <tr>
     <td><code>cascadeTruncate</code></td>
     <td>the default cascading truncate behaviour of the JDBC database in question, specified in the <code>isCascadeTruncate</code> in each JDBCDialect</td>
@@ -275,11 +275,22 @@ logging into the data sources.
     </td>
     <td>read/write</td>
   </tr>  
+
+  <tr>
+    <td><code>connectionProvider</code></td>
+    <td>(none)</td>
+    <td>
+      The name of the JDBC connection provider to use to connect to this URL, e.g. <code>db2</code>, <code>mssql</code>.
+      Must be one of the providers loaded with the JDBC data source. Used to disambiguate when more than one provider can handle
+      the specified driver and options. The selected provider must not be disabled by <code>spark.sql.sources.disabledJdbcConnProviderList</code>. 
+    </td>
+    <td>read/write</td>
+ </tr>  
 </table>
 
 Note that kerberos authentication with keytab is not always supported by the JDBC driver.<br>
 Before using <code>keytab</code> and <code>principal</code> configuration options, please make sure the following requirements are met:
-* The included JDBC driver version supports kerberos authentication with keytab. 
+* The included JDBC driver version supports kerberos authentication with keytab.
 * There is a built-in connection provider which supports the used database.
 
 There is a built-in connection providers for the following databases:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 97d4f2d..e3baafb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -207,6 +207,9 @@ class JDBCOptions(
   val tableComment = parameters.getOrElse(JDBC_TABLE_COMMENT, "").toString
 
   val refreshKrb5Config = parameters.getOrElse(JDBC_REFRESH_KRB5_CONFIG, "false").toBoolean
+
+  // User specified JDBC connection provider name
+  val connectionProviderName = parameters.get(JDBC_CONNECTION_PROVIDER)
 }
 
 class JdbcOptionsInWrite(
@@ -263,4 +266,5 @@ object JDBCOptions {
   val JDBC_PRINCIPAL = newOption("principal")
   val JDBC_TABLE_COMMENT = newOption("tableComment")
   val JDBC_REFRESH_KRB5_CONFIG = newOption("refreshKrb5Config")
+  val JDBC_CONNECTION_PROVIDER = newOption("connectionProvider")
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 60fcaf9..7b555bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -61,7 +61,8 @@ object JdbcUtils extends Logging {
     () => {
       DriverRegistry.register(driverClass)
       val driver: Driver = DriverRegistry.get(driverClass)
-      val connection = ConnectionProvider.create(driver, options.parameters)
+      val connection =
+        ConnectionProvider.create(driver, options.parameters, options.connectionProviderName)
       require(connection != null,
         s"The driver could not open a JDBC connection. Check the URL: ${options.url}")
       connection
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProvider.scala
index 890205f..66854f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProvider.scala
@@ -42,7 +42,7 @@ private[jdbc] class BasicConnectionProvider extends JdbcConnectionProvider with
   override def getConnection(driver: Driver, options: Map[String, String]): Connection = {
     val jdbcOptions = new JDBCOptions(options)
     val properties = getAdditionalProperties(jdbcOptions)
-    jdbcOptions.asProperties.asScala.foreach { case(k, v) =>
+    jdbcOptions.asConnectionProperties.asScala.foreach { case(k, v) =>
       properties.put(k, v)
     }
     logDebug(s"JDBC connection initiated with URL: ${jdbcOptions.url} and properties: $properties")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala
index fbc6970..e3d8275 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala
@@ -25,12 +25,13 @@ import scala.collection.mutable
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.security.SecurityConfigurationLock
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.jdbc.JdbcConnectionProvider
 import org.apache.spark.util.Utils
 
-private[jdbc] object ConnectionProvider extends Logging {
-  private val providers = loadProviders()
+protected abstract class ConnectionProviderBase extends Logging {
+  protected val providers = loadProviders()
 
   def loadProviders(): Seq[JdbcConnectionProvider] = {
     val loader = ServiceLoader.load(classOf[JdbcConnectionProvider],
@@ -55,17 +56,42 @@ private[jdbc] object ConnectionProvider extends Logging {
     providers.filterNot(p => disabledProviders.contains(p.name)).toSeq
   }
 
-  def create(driver: Driver, options: Map[String, String]): Connection = {
+  def create(
+      driver: Driver,
+      options: Map[String, String],
+      connectionProviderName: Option[String]): Connection = {
     val filteredProviders = providers.filter(_.canHandle(driver, options))
-    require(filteredProviders.size == 1,
-      "JDBC connection initiated but not exactly one connection provider found which can handle " +
-        s"it. Found active providers: ${filteredProviders.mkString(", ")}")
+
+    if (filteredProviders.isEmpty) {
+      throw new IllegalArgumentException(
+        "Empty list of JDBC connection providers for the specified driver and options")
+    }
+
+    val selectedProvider = connectionProviderName match {
+      case Some(providerName) =>
+        // It is assumed that no two providers will have the same name
+        filteredProviders.filter(_.name == providerName).headOption.getOrElse {
+          throw new IllegalArgumentException(
+            s"Could not find a JDBC connection provider with name '$providerName' " +
+            "that can handle the specified driver and options. " +
+            s"Available providers are ${providers.mkString("[", ", ", "]")}")
+        }
+      case None =>
+        if (filteredProviders.size != 1) {
+          throw new IllegalArgumentException(
+            "JDBC connection initiated but more than one connection provider was found. Use " +
+            s"'${JDBCOptions.JDBC_CONNECTION_PROVIDER}' option to select a specific provider. " +
+            s"Found active providers ${filteredProviders.mkString("[", ", ", "]")}")
+        }
+        filteredProviders.head
+    }
+
     SecurityConfigurationLock.synchronized {
       // Inside getConnection it's safe to get parent again because SecurityConfigurationLock
       // makes sure it's untouched
       val parent = Configuration.getConfiguration
       try {
-        filteredProviders.head.getConnection(driver, options)
+        selectedProvider.getConnection(driver, options)
       } finally {
         logDebug("Restoring original security configuration")
         Configuration.setConfiguration(parent)
@@ -73,3 +99,5 @@ private[jdbc] object ConnectionProvider extends Logging {
     }
   }
 }
+
+private[jdbc] object ConnectionProvider extends ConnectionProviderBase
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/README.md b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/README.md
index f8a4ae0..72196be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/README.md
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/README.md
@@ -6,9 +6,9 @@ license: |
   The ASF licenses this file to You under the Apache License, Version 2.0
   (the "License"); you may not use this file except in compliance with
   the License.  You may obtain a copy of the License at
- 
+
      http://www.apache.org/licenses/LICENSE-2.0
- 
+
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -46,6 +46,12 @@ so they can be turned off and can be replaced with custom implementation. All CP
 which must be unique. One can set the following configuration entry in `SparkConf` to turn off CPs:
 `spark.sql.sources.disabledJdbcConnProviderList=name1,name2`.
 
+## How to enforce a specific JDBC connection provider?
+
+When more than one JDBC connection provider can handle a specific driver and options, it is possible to
+disambiguate and enforce a particular CP for the JDBC data source. One can set the DataFrame
+option `connectionProvider` to specify the name of the CP they want to use.
+
 ## How a JDBC connection provider found when new connection initiated?
 
 When a Spark source initiates JDBC connection it looks for a CP which supports the included driver,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProviderSuite.scala
new file mode 100644
index 0000000..823fdca
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProviderSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc.connection
+
+import java.sql.{Connection, Driver}
+import java.util.Properties
+
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito.when
+import org.mockito.invocation.InvocationOnMock
+import org.scalatestplus.mockito.MockitoSugar
+
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+
+class BasicConnectionProviderSuite extends ConnectionProviderSuiteBase with MockitoSugar {
+  test("Check properties of BasicConnectionProvider") {
+    val opts = options("jdbc:postgresql://localhost/postgres")
+    val provider = new BasicConnectionProvider()
+    assert(provider.name == "basic")
+    assert(provider.getAdditionalProperties(opts).isEmpty())
+  }
+
+  test("Check that JDBC options don't contain data source configs") {
+    val provider = new BasicConnectionProvider()
+    val driver = mock[Driver]
+    when(driver.connect(any(), any())).thenAnswer((invocation: InvocationOnMock) => {
+      val props = invocation.getArguments().apply(1).asInstanceOf[Properties]
+      val conn = mock[Connection]
+      when(conn.getClientInfo()).thenReturn(props)
+      conn
+    })
+
+    val opts = Map(
+      JDBCOptions.JDBC_URL -> "jdbc:postgresql://localhost/postgres",
+      JDBCOptions.JDBC_TABLE_NAME -> "table",
+      JDBCOptions.JDBC_CONNECTION_PROVIDER -> "basic")
+    val conn = provider.getConnection(driver, opts)
+    assert(!conn.getClientInfo().containsKey(JDBCOptions.JDBC_URL))
+    assert(!conn.getClientInfo().containsKey(JDBCOptions.JDBC_TABLE_NAME))
+    assert(!conn.getClientInfo().containsKey(JDBCOptions.JDBC_CONNECTION_PROVIDER))
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuite.scala
index 32d8fce..6674483 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuite.scala
@@ -17,13 +17,21 @@
 
 package org.apache.spark.sql.execution.datasources.jdbc.connection
 
+import java.sql.{Connection, Driver}
 import javax.security.auth.login.Configuration
 
+import org.scalatestplus.mockito.MockitoSugar
+
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.internal.StaticSQLConf
+import org.apache.spark.sql.jdbc.JdbcConnectionProvider
 import org.apache.spark.sql.test.SharedSparkSession
 
-class ConnectionProviderSuite extends ConnectionProviderSuiteBase with SharedSparkSession {
+class ConnectionProviderSuite
+  extends ConnectionProviderSuiteBase
+  with SharedSparkSession
+  with MockitoSugar {
+
   test("All built-in providers must be loaded") {
     IntentionallyFaultyConnectionProvider.constructed = false
     val providers = ConnectionProvider.loadProviders()
@@ -38,6 +46,84 @@ class ConnectionProviderSuite extends ConnectionProviderSuiteBase with SharedSpa
     assert(providers.size === 6)
   }
 
+  test("Throw an error selecting from an empty list of providers on create") {
+    val providerBase = new ConnectionProviderBase() {
+      override val providers = Seq.empty
+    }
+
+    val err1 = intercept[IllegalArgumentException] {
+      providerBase.create(mock[Driver], Map.empty, None)
+    }
+    assert(err1.getMessage.contains("Empty list of JDBC connection providers"))
+
+    val err2 = intercept[IllegalArgumentException] {
+      providerBase.create(mock[Driver], Map.empty, Some("test"))
+    }
+    assert(err2.getMessage.contains("Empty list of JDBC connection providers"))
+  }
+
+  test("Throw an error when more than one provider is available on create") {
+    val provider1 = new JdbcConnectionProvider() {
+      override val name: String = "test1"
+      override def canHandle(driver: Driver, options: Map[String, String]): Boolean = true
+      override def getConnection(driver: Driver, options: Map[String, String]): Connection =
+        throw new RuntimeException()
+    }
+    val provider2 = new JdbcConnectionProvider() {
+      override val name: String = "test2"
+      override def canHandle(driver: Driver, options: Map[String, String]): Boolean = true
+      override def getConnection(driver: Driver, options: Map[String, String]): Connection =
+        throw new RuntimeException()
+    }
+
+    val providerBase = new ConnectionProviderBase() {
+      override val providers = Seq(provider1, provider2)
+    }
+
+    val err = intercept[IllegalArgumentException] {
+      providerBase.create(mock[Driver], Map.empty, None)
+    }
+    assert(err.getMessage.contains("more than one connection provider was found"))
+  }
+
+  test("Handle user specified JDBC connection provider") {
+    val provider1 = new JdbcConnectionProvider() {
+      override val name: String = "test1"
+      override def canHandle(driver: Driver, options: Map[String, String]): Boolean = true
+      override def getConnection(driver: Driver, options: Map[String, String]): Connection =
+        throw new RuntimeException()
+    }
+    val provider2 = new JdbcConnectionProvider() {
+      override val name: String = "test2"
+      override def canHandle(driver: Driver, options: Map[String, String]): Boolean = true
+      override def getConnection(driver: Driver, options: Map[String, String]): Connection =
+        mock[Connection]
+    }
+
+    val providerBase = new ConnectionProviderBase() {
+      override val providers = Seq(provider1, provider2)
+    }
+    // We don't expect any exceptions or null here
+    assert(providerBase.create(mock[Driver], Map.empty, Some("test2")).isInstanceOf[Connection])
+  }
+
+  test("Throw an error when user specified provider that does not exist") {
+    val provider = new JdbcConnectionProvider() {
+      override val name: String = "provider"
+      override def canHandle(driver: Driver, options: Map[String, String]): Boolean = true
+      override def getConnection(driver: Driver, options: Map[String, String]): Connection =
+        throw new RuntimeException()
+    }
+
+    val providerBase = new ConnectionProviderBase() {
+      override val providers = Seq(provider)
+    }
+    val err = intercept[IllegalArgumentException] {
+      providerBase.create(mock[Driver], Map.empty, Some("test"))
+    }
+    assert(err.getMessage.contains("Could not find a JDBC connection provider with name 'test'"))
+  }
+
   test("Multiple security configs must be reachable") {
     Configuration.setConfiguration(null)
     val postgresProvider = new PostgresConnectionProvider()
