Posted to commits@spark.apache.org by yh...@apache.org on 2016/01/04 19:39:46 UTC
spark git commit: [SPARK-12579][SQL] Force user-specified JDBC driver to take precedence
Repository: spark
Updated Branches:
refs/heads/master 8f659393b -> 6c83d938c
[SPARK-12579][SQL] Force user-specified JDBC driver to take precedence
Spark SQL's JDBC data source allows users to specify an explicit JDBC driver to load (using the `driver` argument), but in the current code it's possible that the user-specified driver will not be used when it comes time to actually create a JDBC connection.
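For context, a user might request a specific driver like this (a minimal illustration; the URL, table name, and driver class are placeholder values, while the `url`, `dbtable`, and `driver` option keys are the data source's documented ones):

```scala
// Illustrative only: ask the JDBC data source to use a specific driver class.
val df = sqlContext.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://host:5432/db")  // hypothetical database URL
  .option("dbtable", "my_table")                    // hypothetical table name
  .option("driver", "org.postgresql.Driver")        // the driver the user wants used
  .load()
```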
In a nutshell, the problem is that you might have multiple JDBC drivers on the classpath that claim to be able to handle the same subprotocol, so simply registering the user-provided driver class with our `DriverRegistry` and JDBC's `DriverManager` is not sufficient to ensure that it's actually used when creating the JDBC connection.
This patch addresses the issue by first registering the user-specified driver with the `DriverManager`, then iterating over the driver manager's loaded drivers to find the matching driver and using it to create a connection (previously, we just called `DriverManager.getConnection()` directly).
If a user did not specify a JDBC driver, we call `DriverManager.getDriver` on the Spark driver to determine the driver class to use, then pass that class name to the executors. This guards against corner-case bugs where the driver and executor JVMs have different sets of JDBC drivers on their classpaths: previously, `DriverManager.getConnection()` could (rarely) pick different drivers on the driver and executors if the user had not explicitly specified a JDBC driver class and the classpaths differed.
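A minimal sketch of the connection-factory logic described above (simplified from the actual change in `JdbcUtils.createConnectionFactory` in the diff below; the standalone helper name and error message here are illustrative, and `DriverWrapper` unwrapping is omitted for brevity):

```scala
import java.sql.{Connection, Driver, DriverManager}
import java.util.Properties
import scala.collection.JavaConverters._

def connectionFactory(url: String, properties: Properties): () => Connection = {
  // Resolve the driver class name on the Spark driver so that executors use the
  // same class, even if their classpaths would map the URL to a different driver.
  val driverClass = Option(properties.getProperty("driver"))
    .getOrElse(DriverManager.getDriver(url).getClass.getCanonicalName)
  () => {
    // Pick the registered driver whose class matches, instead of letting
    // DriverManager.getConnection() choose the first driver that accepts the URL.
    val driver: Driver = DriverManager.getDrivers.asScala
      .find(_.getClass.getCanonicalName == driverClass)
      .getOrElse(throw new IllegalStateException(
        s"Did not find registered driver with class $driverClass"))
    driver.connect(url, properties)
  }
}
```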
This patch is inspired by a similar patch that I made to the `spark-redshift` library (https://github.com/databricks/spark-redshift/pull/143), which contains its own modified fork of some of Spark's JDBC data source code (for cross-Spark-version compatibility reasons).
Author: Josh Rosen <jo...@databricks.com>
Closes #10519 from JoshRosen/jdbc-driver-precedence.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6c83d938
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6c83d938
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6c83d938
Branch: refs/heads/master
Commit: 6c83d938cc61bd5fabaf2157fcc3936364a83f02
Parents: 8f65939
Author: Josh Rosen <jo...@databricks.com>
Authored: Mon Jan 4 10:39:42 2016 -0800
Committer: Yin Huai <yh...@databricks.com>
Committed: Mon Jan 4 10:39:42 2016 -0800
----------------------------------------------------------------------
docs/sql-programming-guide.md | 4 +--
.../org/apache/spark/sql/DataFrameWriter.scala | 2 +-
.../datasources/jdbc/DefaultSource.scala | 3 --
.../datasources/jdbc/DriverRegistry.scala | 5 ---
.../execution/datasources/jdbc/JDBCRDD.scala | 33 +++---------------
.../datasources/jdbc/JDBCRelation.scala | 2 --
.../execution/datasources/jdbc/JdbcUtils.scala | 35 ++++++++++++++++----
7 files changed, 34 insertions(+), 50 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/6c83d938/docs/sql-programming-guide.md
----------------------------------------------------------------------
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 3f9a831..b058833 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1895,9 +1895,7 @@ the Data Sources API. The following options are supported:
<tr>
<td><code>driver</code></td>
<td>
- The class name of the JDBC driver needed to connect to this URL. This class will be loaded
- on the master and workers before running an JDBC commands to allow the driver to
- register itself with the JDBC subsystem.
+ The class name of the JDBC driver to use to connect to this URL.
</td>
</tr>
http://git-wip-us.apache.org/repos/asf/spark/blob/6c83d938/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index ab36253..9f59c0f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -275,7 +275,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
}
// connectionProperties should override settings in extraOptions
props.putAll(connectionProperties)
- val conn = JdbcUtils.createConnection(url, props)
+ val conn = JdbcUtils.createConnectionFactory(url, props)()
try {
var tableExists = JdbcUtils.tableExists(conn, url, table)
http://git-wip-us.apache.org/repos/asf/spark/blob/6c83d938/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala
index f522303..5ae6cff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala
@@ -31,15 +31,12 @@ class DefaultSource extends RelationProvider with DataSourceRegister {
sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {
val url = parameters.getOrElse("url", sys.error("Option 'url' not specified"))
- val driver = parameters.getOrElse("driver", null)
val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified"))
val partitionColumn = parameters.getOrElse("partitionColumn", null)
val lowerBound = parameters.getOrElse("lowerBound", null)
val upperBound = parameters.getOrElse("upperBound", null)
val numPartitions = parameters.getOrElse("numPartitions", null)
- if (driver != null) DriverRegistry.register(driver)
-
if (partitionColumn != null
&& (lowerBound == null || upperBound == null || numPartitions == null)) {
sys.error("Partitioning incompletely specified")
http://git-wip-us.apache.org/repos/asf/spark/blob/6c83d938/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
index 7ccd61e..65af397 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
@@ -51,10 +51,5 @@ object DriverRegistry extends Logging {
}
}
}
-
- def getDriverClassName(url: String): String = DriverManager.getDriver(url) match {
- case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName
- case driver => driver.getClass.getCanonicalName
- }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/6c83d938/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 87d43ad..cb8d950 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.datasources.jdbc
-import java.sql.{Connection, Date, DriverManager, ResultSet, ResultSetMetaData, SQLException, Timestamp}
+import java.sql.{Connection, Date, ResultSet, ResultSetMetaData, SQLException, Timestamp}
import java.util.Properties
import scala.util.control.NonFatal
@@ -41,7 +41,6 @@ private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Par
override def index: Int = idx
}
-
private[sql] object JDBCRDD extends Logging {
/**
@@ -120,7 +119,7 @@ private[sql] object JDBCRDD extends Logging {
*/
def resolveTable(url: String, table: String, properties: Properties): StructType = {
val dialect = JdbcDialects.get(url)
- val conn: Connection = getConnector(properties.getProperty("driver"), url, properties)()
+ val conn: Connection = JdbcUtils.createConnectionFactory(url, properties)()
try {
val statement = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0")
try {
@@ -228,36 +227,13 @@ private[sql] object JDBCRDD extends Logging {
})
}
- /**
- * Given a driver string and an url, return a function that loads the
- * specified driver string then returns a connection to the JDBC url.
- * getConnector is run on the driver code, while the function it returns
- * is run on the executor.
- *
- * @param driver - The class name of the JDBC driver for the given url, or null if the class name
- * is not necessary.
- * @param url - The JDBC url to connect to.
- *
- * @return A function that loads the driver and connects to the url.
- */
- def getConnector(driver: String, url: String, properties: Properties): () => Connection = {
- () => {
- try {
- if (driver != null) DriverRegistry.register(driver)
- } catch {
- case e: ClassNotFoundException =>
- logWarning(s"Couldn't find class $driver", e)
- }
- DriverManager.getConnection(url, properties)
- }
- }
+
/**
* Build and return JDBCRDD from the given information.
*
* @param sc - Your SparkContext.
* @param schema - The Catalyst schema of the underlying database table.
- * @param driver - The class name of the JDBC driver for the given url.
* @param url - The JDBC url to connect to.
* @param fqTable - The fully-qualified table name (or paren'd SQL query) to use.
* @param requiredColumns - The names of the columns to SELECT.
@@ -270,7 +246,6 @@ private[sql] object JDBCRDD extends Logging {
def scanTable(
sc: SparkContext,
schema: StructType,
- driver: String,
url: String,
properties: Properties,
fqTable: String,
@@ -281,7 +256,7 @@ private[sql] object JDBCRDD extends Logging {
val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
new JDBCRDD(
sc,
- getConnector(driver, url, properties),
+ JdbcUtils.createConnectionFactory(url, properties),
pruneSchema(schema, requiredColumns),
fqTable,
quotedColumns,
http://git-wip-us.apache.org/repos/asf/spark/blob/6c83d938/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index f9300dc..375266f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -91,12 +91,10 @@ private[sql] case class JDBCRelation(
override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)
override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
- val driver: String = DriverRegistry.getDriverClassName(url)
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JDBCRDD.scanTable(
sqlContext.sparkContext,
schema,
- driver,
url,
properties,
table,
http://git-wip-us.apache.org/repos/asf/spark/blob/6c83d938/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 46f2670..10f6506 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql.execution.datasources.jdbc
-import java.sql.{Connection, PreparedStatement}
+import java.sql.{Connection, Driver, DriverManager, PreparedStatement}
import java.util.Properties
+import scala.collection.JavaConverters._
import scala.util.Try
import scala.util.control.NonFatal
@@ -34,10 +35,31 @@ import org.apache.spark.sql.{DataFrame, Row}
object JdbcUtils extends Logging {
/**
- * Establishes a JDBC connection.
+ * Returns a factory for creating connections to the given JDBC URL.
+ *
+ * @param url the JDBC url to connect to.
+ * @param properties JDBC connection properties.
*/
- def createConnection(url: String, connectionProperties: Properties): Connection = {
- JDBCRDD.getConnector(connectionProperties.getProperty("driver"), url, connectionProperties)()
+ def createConnectionFactory(url: String, properties: Properties): () => Connection = {
+ val userSpecifiedDriverClass = Option(properties.getProperty("driver"))
+ userSpecifiedDriverClass.foreach(DriverRegistry.register)
+ // Performing this part of the logic on the driver guards against the corner-case where the
+ // driver returned for a URL is different on the driver and executors due to classpath
+ // differences.
+ val driverClass: String = userSpecifiedDriverClass.getOrElse {
+ DriverManager.getDriver(url).getClass.getCanonicalName
+ }
+ () => {
+ userSpecifiedDriverClass.foreach(DriverRegistry.register)
+ val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
+ case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d
+ case d if d.getClass.getCanonicalName == driverClass => d
+ }.getOrElse {
+ throw new IllegalStateException(
+ s"Did not find registered driver with class $driverClass")
+ }
+ driver.connect(url, properties)
+ }
}
/**
@@ -242,15 +264,14 @@ object JdbcUtils extends Logging {
df: DataFrame,
url: String,
table: String,
- properties: Properties = new Properties()) {
+ properties: Properties) {
val dialect = JdbcDialects.get(url)
val nullTypes: Array[Int] = df.schema.fields.map { field =>
getJdbcType(field.dataType, dialect).jdbcNullType
}
val rddSchema = df.schema
- val driver: String = DriverRegistry.getDriverClassName(url)
- val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
+ val getConnection: () => Connection = createConnectionFactory(url, properties)
val batchSize = properties.getProperty("batchsize", "1000").toInt
df.foreachPartition { iterator =>
savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)