Posted to commits@spark.apache.org by ma...@apache.org on 2015/05/18 20:55:46 UTC

spark git commit: [SPARK-6888] [SQL] Make the jdbc driver handling user-definable

Repository: spark
Updated Branches:
  refs/heads/master 563bfcc1a -> e1ac2a955


[SPARK-6888] [SQL] Make the jdbc driver handling user-definable

Replace the DriverQuirks with JdbcDialect(s) (and MySQLDialect/PostgresDialect)
and allow developers to change the dialects on the fly (for new JDBCRDDs only).

Some types (like an unsigned 64-bit number) cannot be trivially mapped to Java.
The status quo is that the RDD will fail to load.
This patch makes it possible to override the type mapping, e.g. to read
64-bit numbers as strings and handle them afterwards in software.

JDBCSuite has an example that maps all types to String, which should always
work (at the cost of extra code afterwards).

As a side effect it should now be possible to develop simple dialects
out-of-tree and even with spark-shell.
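
As a hedged illustration only (not part of this patch): a minimal spark-shell
sketch that registers such a map-everything-to-String dialect, modelled on the
testH2Dialect in JDBCSuite. The jdbc URL, the table name and the use of the
shell's sqlContext are placeholder assumptions.

    import java.util.Properties
    import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
    import org.apache.spark.sql.types._

    // Hypothetical dialect: read every MySQL column as a String, e.g. to keep
    // unsigned 64-bit values that have no exact Java/Catalyst counterpart.
    val stringDialect = new JdbcDialect {
      def canHandle(url: String): Boolean = url.startsWith("jdbc:mysql")
      override def getCatalystType(
          sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
        Some(StringType)
    }

    JdbcDialects.registerDialect(stringDialect)
    // Only DataFrames created after registration pick up the dialect.
    val df = sqlContext.read.jdbc("jdbc:mysql://localhost/db", "some_table", new Properties)
    JdbcDialects.unregisterDialect(stringDialect)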

Author: Rene Treffer <tr...@measite.de>

Closes #5555 from rtreffer/jdbc-dialects and squashes the following commits:

3cbafd7 [Rene Treffer] [SPARK-6888] ignore classes belonging to changed API in MIMA report
fe7e2e8 [Rene Treffer] [SPARK-6888] Make the jdbc driver handling user-definable


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e1ac2a95
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e1ac2a95
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e1ac2a95

Branch: refs/heads/master
Commit: e1ac2a955be64b8df197195e3b225271cfa8201f
Parents: 563bfcc
Author: Rene Treffer <tr...@measite.de>
Authored: Mon May 18 11:55:36 2015 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Mon May 18 11:55:36 2015 -0700

----------------------------------------------------------------------
 project/MimaExcludes.scala                      |   8 +
 .../apache/spark/sql/jdbc/DriverQuirks.scala    |  99 ---------
 .../org/apache/spark/sql/jdbc/JDBCRDD.scala     |  11 +-
 .../apache/spark/sql/jdbc/JdbcDialects.scala    | 211 +++++++++++++++++++
 .../scala/org/apache/spark/sql/jdbc/jdbc.scala  |  43 ++--
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala   |  49 +++++
 6 files changed, 295 insertions(+), 126 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e1ac2a95/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 487062a..513bbaf 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -137,6 +137,14 @@ object MimaExcludes {
             // implementing this interface in Java. Note that ShuffleWriter is private[spark].
             ProblemFilters.exclude[IncompatibleTemplateDefProblem](
               "org.apache.spark.shuffle.ShuffleWriter")
+          ) ++ Seq(
+            // SPARK-6888 make jdbc driver handling user definable
+            // This patch renames some classes to API friendly names.
+            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks$"),
+            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks"),
+            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.PostgresQuirks"),
+            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.NoQuirks"),
+            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.MySQLQuirks")
           )
 
         case v if v.startsWith("1.3") =>

http://git-wip-us.apache.org/repos/asf/spark/blob/e1ac2a95/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala
deleted file mode 100644
index 0feabc4..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.jdbc
-
-import org.apache.spark.sql.types._
-
-import java.sql.Types
-
-
-/**
- * Encapsulates workarounds for the extensions, quirks, and bugs in various
- * databases.  Lots of databases define types that aren't explicitly supported
- * by the JDBC spec.  Some JDBC drivers also report inaccurate
- * information---for instance, BIT(n>1) being reported as a BIT type is quite
- * common, even though BIT in JDBC is meant for single-bit values.  Also, there
- * does not appear to be a standard name for an unbounded string or binary
- * type; we use BLOB and CLOB by default but override with database-specific
- * alternatives when these are absent or do not behave correctly.
- *
- * Currently, the only thing DriverQuirks does is handle type mapping.
- * `getCatalystType` is used when reading from a JDBC table and `getJDBCType`
- * is used when writing to a JDBC table.  If `getCatalystType` returns `null`,
- * the default type handling is used for the given JDBC type.  Similarly,
- * if `getJDBCType` returns `(null, None)`, the default type handling is used
- * for the given Catalyst type.
- */
-private[sql] abstract class DriverQuirks {
-  def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType
-  def getJDBCType(dt: DataType): (String, Option[Int])
-}
-
-private[sql] object DriverQuirks {
-  /**
-   * Fetch the DriverQuirks class corresponding to a given database url.
-   */
-  def get(url: String): DriverQuirks = {
-    if (url.startsWith("jdbc:mysql")) {
-      new MySQLQuirks()
-    } else if (url.startsWith("jdbc:postgresql")) {
-      new PostgresQuirks()
-    } else {
-      new NoQuirks()
-    }
-  }
-}
-
-private[sql] class NoQuirks extends DriverQuirks {
-  def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType =
-    null
-  def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None)
-}
-
-private[sql] class PostgresQuirks extends DriverQuirks {
-  def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
-    if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
-      BinaryType
-    } else if (sqlType == Types.OTHER && typeName.equals("cidr")) {
-      StringType
-    } else if (sqlType == Types.OTHER && typeName.equals("inet")) {
-      StringType
-    } else null
-  }
-
-  def getJDBCType(dt: DataType): (String, Option[Int]) = dt match {
-    case StringType => ("TEXT", Some(java.sql.Types.CHAR))
-    case BinaryType => ("BYTEA", Some(java.sql.Types.BINARY))
-    case BooleanType => ("BOOLEAN", Some(java.sql.Types.BOOLEAN))
-    case _ => (null, None)
-  }
-}
-
-private[sql] class MySQLQuirks extends DriverQuirks {
-  def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
-    if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
-      // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as
-      // byte arrays instead of longs.
-      md.putLong("binarylong", 1)
-      LongType
-    } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) {
-      BooleanType
-    } else null
-  }
-  def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None)
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/e1ac2a95/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
index 4189dfc..f7b1909 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -41,7 +41,7 @@ private[sql] object JDBCRDD extends Logging {
 
   /**
    * Maps a JDBC type to a Catalyst type.  This function is called only when
-   * the DriverQuirks class corresponding to your database driver returns null.
+   * the JdbcDialect class corresponding to your database driver returns null.
    *
    * @param sqlType - A field of java.sql.Types
    * @return The Catalyst type corresponding to sqlType.
@@ -51,7 +51,7 @@ private[sql] object JDBCRDD extends Logging {
       case java.sql.Types.ARRAY         => null
       case java.sql.Types.BIGINT        => LongType
       case java.sql.Types.BINARY        => BinaryType
-      case java.sql.Types.BIT           => BooleanType // Per JDBC; Quirks handles quirky drivers.
+      case java.sql.Types.BIT           => BooleanType // @see JdbcDialect for quirks
       case java.sql.Types.BLOB          => BinaryType
       case java.sql.Types.BOOLEAN       => BooleanType
       case java.sql.Types.CHAR          => StringType
@@ -108,7 +108,7 @@ private[sql] object JDBCRDD extends Logging {
    * @throws SQLException if the table contains an unsupported type.
    */
   def resolveTable(url: String, table: String, properties: Properties): StructType = {
-    val quirks = DriverQuirks.get(url)
+    val dialect = JdbcDialects.get(url)
     val conn: Connection = DriverManager.getConnection(url, properties)
     try {
       val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery()
@@ -125,8 +125,9 @@ private[sql] object JDBCRDD extends Logging {
           val fieldScale = rsmd.getScale(i + 1)
           val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
           val metadata = new MetadataBuilder().putString("name", columnName)
-          var columnType = quirks.getCatalystType(dataType, typeName, fieldSize, metadata)
-          if (columnType == null) columnType = getCatalystType(dataType, fieldSize, fieldScale)
+          val columnType =
+            dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
+              getCatalystType(dataType, fieldSize, fieldScale))
           fields(i) = StructField(columnName, columnType, nullable, metadata.build())
           i = i + 1
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/e1ac2a95/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
new file mode 100644
index 0000000..6a169e1
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import org.apache.spark.sql.types._
+import org.apache.spark.annotation.DeveloperApi
+
+import java.sql.Types
+
+/**
+ * :: DeveloperApi ::
+ * A database type definition coupled with the jdbc type needed to send null
+ * values to the database.
+ * @param databaseTypeDefinition The database type definition
+ * @param jdbcNullType The jdbc type (as defined in java.sql.Types) used to
+ *                     send a null value to the database.
+ */
+@DeveloperApi
+case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int)
+
+/**
+ * :: DeveloperApi ::
+ * Encapsulates everything (extensions, workarounds, quirks) to handle the
+ * SQL dialect of a certain database or jdbc driver.
+ * Lots of databases define types that aren't explicitly supported
+ * by the JDBC spec.  Some JDBC drivers also report inaccurate
+ * information---for instance, BIT(n>1) being reported as a BIT type is quite
+ * common, even though BIT in JDBC is meant for single-bit values.  Also, there
+ * does not appear to be a standard name for an unbounded string or binary
+ * type; we use BLOB and CLOB by default but override with database-specific
+ * alternatives when these are absent or do not behave correctly.
+ *
+ * Currently, the only thing done by the dialect is type mapping.
+ * `getCatalystType` is used when reading from a JDBC table and `getJDBCType`
+ * is used when writing to a JDBC table.  If `getCatalystType` returns `null`,
+ * the default type handling is used for the given JDBC type.  Similarly,
+ * if `getJDBCType` returns `(null, None)`, the default type handling is used
+ * for the given Catalyst type.
+ */
+@DeveloperApi
+abstract class JdbcDialect {
+  /**
+   * Check if this dialect instance can handle a certain jdbc url.
+   * @param url the jdbc url.
+   * @return True if the dialect can be applied on the given jdbc url.
+   * @throws NullPointerException if the url is null.
+   */
+  def canHandle(url : String): Boolean
+
+  /**
+   * Get the custom datatype mapping for the given jdbc meta information.
+   * @param sqlType The sql type (see java.sql.Types)
+   * @param typeName The sql type name (e.g. "BIGINT UNSIGNED")
+   * @param size The size of the type.
+   * @param md Result metadata associated with this type.
+   * @return The actual DataType (subclasses of [[org.apache.spark.sql.types.DataType]])
+   *         or None if the default type mapping should be used.
+   */
+  def getCatalystType(
+    sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = None
+
+  /**
+   * Retrieve the jdbc / sql type for a given datatype.
+   * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
+   * @return The new JdbcType if there is an override for this DataType
+   */
+  def getJDBCType(dt: DataType): Option[JdbcType] = None
+}
+
+/**
+ * :: DeveloperApi ::
+ * Registry of dialects that apply to every new jdbc [[org.apache.spark.sql.DataFrame]].
+ *
+ * If multiple matching dialects are registered then all matching ones will be
+ * tried in reverse order. A user-added dialect will thus be applied first,
+ * overwriting the defaults.
+ *
+ * Note that all new dialects are applied to new jdbc DataFrames only. Make
+ * sure to register your dialects first.
+ */
+@DeveloperApi
+object JdbcDialects {
+
+  private var dialects = List[JdbcDialect]()
+
+  /**
+   * Register a dialect for use on all new matching jdbc [[org.apache.spark.sql.DataFrame]].
+   * Re-adding an existing dialect will cause a move-to-front.
+   * @param dialect The new dialect.
+   */
+  def registerDialect(dialect: JdbcDialect) : Unit = {
+    dialects = dialect :: dialects.filterNot(_ == dialect)
+  }
+
+  /**
+   * Unregister a dialect. Does nothing if the dialect is not registered.
+   * @param dialect The jdbc dialect.
+   */
+  def unregisterDialect(dialect : JdbcDialect) : Unit = {
+    dialects = dialects.filterNot(_ == dialect)
+  }
+
+  registerDialect(MySQLDialect)
+  registerDialect(PostgresDialect)
+
+  /**
+   * Fetch the JdbcDialect class corresponding to a given database url.
+   */
+  private[sql] def get(url: String): JdbcDialect = {
+    val matchingDialects = dialects.filter(_.canHandle(url))
+    matchingDialects.length match {
+      case 0 => NoopDialect
+      case 1 => matchingDialects.head
+      case _ => new AggregatedDialect(matchingDialects)
+    }
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ * AggregatedDialect can unify multiple dialects into one virtual Dialect.
+ * Dialects are tried in order, and the first dialect that does not return a
+ * neutral element will win.
+ * @param dialects List of dialects.
+ */
+@DeveloperApi
+class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect {
+
+  require(!dialects.isEmpty)
+
+  def canHandle(url : String): Boolean =
+    dialects.map(_.canHandle(url)).reduce(_ && _)
+
+  override def getCatalystType(
+      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
+    dialects.map(_.getCatalystType(sqlType, typeName, size, md)).flatten.headOption
+
+  override def getJDBCType(dt: DataType): Option[JdbcType] =
+    dialects.map(_.getJDBCType(dt)).flatten.headOption
+
+}
+
+/**
+ * :: DeveloperApi ::
+ * NOOP dialect object, always returning the neutral element.
+ */
+@DeveloperApi
+case object NoopDialect extends JdbcDialect {
+  def canHandle(url : String): Boolean = true
+}
+
+/**
+ * :: DeveloperApi ::
+ * Default postgres dialect, mapping bit/cidr/inet on read and string/binary/boolean on write.
+ */
+@DeveloperApi
+case object PostgresDialect extends JdbcDialect {
+  def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")
+  override def getCatalystType(
+      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
+    if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
+      Some(BinaryType)
+    } else if (sqlType == Types.OTHER && typeName.equals("cidr")) {
+      Some(StringType)
+    } else if (sqlType == Types.OTHER && typeName.equals("inet")) {
+      Some(StringType)
+    } else None
+  }
+
+  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
+    case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR))
+    case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY))
+    case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN))
+    case _ => None
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Default mysql dialect to read bit/bitsets correctly.
+ */
+@DeveloperApi
+case object MySQLDialect extends JdbcDialect {
+  def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql")
+  override def getCatalystType(
+      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
+    if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
+      // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as
+      // byte arrays instead of longs.
+      md.putLong("binarylong", 1)
+      Some(LongType)
+    } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) {
+      Some(BooleanType)
+    } else None
+  }
+}
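
As a usage note (a sketch, not part of the patch): registerDialect prepends, so
a user dialect registered for the same URL scheme as a built-in one is
consulted first, and the aggregated dialect takes the first non-None answer.
The type name "BIGINT UNSIGNED" below is the example name from the scaladoc
above; the exact string a given driver reports is an assumption.

    import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
    import org.apache.spark.sql.types._

    // Hypothetical selective override: map only unsigned 64-bit columns to
    // StringType and return None for everything else, deferring to
    // MySQLDialect and the built-in mapping for the rest.
    val unsignedAsString = new JdbcDialect {
      def canHandle(url: String): Boolean = url.startsWith("jdbc:mysql")
      override def getCatalystType(
          sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
        if (typeName == "BIGINT UNSIGNED") Some(StringType) else None
    }

    JdbcDialects.registerDialect(unsignedAsString)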

http://git-wip-us.apache.org/repos/asf/spark/blob/e1ac2a95/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
index a61790b..f21dd29 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
@@ -129,25 +129,26 @@ package object jdbc {
      */
     def schemaString(df: DataFrame, url: String): String = {
       val sb = new StringBuilder()
-      val quirks = DriverQuirks.get(url)
+      val dialect = JdbcDialects.get(url)
       df.schema.fields foreach { field => {
         val name = field.name
-        var typ: String = quirks.getJDBCType(field.dataType)._1
-        if (typ == null) typ = field.dataType match {
-          case IntegerType => "INTEGER"
-          case LongType => "BIGINT"
-          case DoubleType => "DOUBLE PRECISION"
-          case FloatType => "REAL"
-          case ShortType => "INTEGER"
-          case ByteType => "BYTE"
-          case BooleanType => "BIT(1)"
-          case StringType => "TEXT"
-          case BinaryType => "BLOB"
-          case TimestampType => "TIMESTAMP"
-          case DateType => "DATE"
-          case DecimalType.Unlimited => "DECIMAL(40,20)"
-          case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
-        }
+        val typ: String =
+          dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
+          field.dataType match {
+            case IntegerType => "INTEGER"
+            case LongType => "BIGINT"
+            case DoubleType => "DOUBLE PRECISION"
+            case FloatType => "REAL"
+            case ShortType => "INTEGER"
+            case ByteType => "BYTE"
+            case BooleanType => "BIT(1)"
+            case StringType => "TEXT"
+            case BinaryType => "BLOB"
+            case TimestampType => "TIMESTAMP"
+            case DateType => "DATE"
+            case DecimalType.Unlimited => "DECIMAL(40,20)"
+            case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
+          })
         val nullable = if (field.nullable) "" else "NOT NULL"
         sb.append(s", $name $typ $nullable")
       }}
@@ -162,10 +163,9 @@ package object jdbc {
         url: String,
         table: String,
         properties: Properties = new Properties()) {
-      val quirks = DriverQuirks.get(url)
+      val dialect = JdbcDialects.get(url)
       val nullTypes: Array[Int] = df.schema.fields.map { field =>
-        val nullType: Option[Int] = quirks.getJDBCType(field.dataType)._2
-        if (nullType.isEmpty) {
+        dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
           field.dataType match {
             case IntegerType => java.sql.Types.INTEGER
             case LongType => java.sql.Types.BIGINT
@@ -181,8 +181,7 @@ package object jdbc {
             case DecimalType.Unlimited => java.sql.Types.DECIMAL
             case _ => throw new IllegalArgumentException(
               s"Can't translate null value for field $field")
-          }
-        } else nullType.get
+          })
       }
 
       val rddSchema = df.schema
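
For the write path, a hedged sketch of the same mechanism (again not part of
the patch): a dialect that overrides getJDBCType so StringType columns are
created as VARCHAR(255) instead of the default TEXT; databaseTypeDefinition
feeds the generated DDL and jdbcNullType is the java.sql.Types constant used
when a null has to be sent. The VARCHAR length is an arbitrary assumption.

    import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
    import org.apache.spark.sql.types._

    // Hypothetical write-side override; all other types fall through to the
    // defaults because getJDBCType returns None for them.
    val varcharDialect = new JdbcDialect {
      def canHandle(url: String): Boolean = url.startsWith("jdbc:mysql")
      override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
        case StringType => Some(JdbcType("VARCHAR(255)", java.sql.Types.VARCHAR))
        case _ => None
      }
    }

    JdbcDialects.registerDialect(varcharDialect)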

http://git-wip-us.apache.org/repos/asf/spark/blob/e1ac2a95/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 5a7b6f0..a8dddfb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -35,6 +35,13 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
 
   val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte)
 
+  val testH2Dialect = new JdbcDialect {
+    def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2")
+    override def getCatalystType(
+        sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
+      Some(StringType)
+  }
+
   before {
     Class.forName("org.h2.Driver")
     // Extra properties that will be specified for our database. We need these to test
@@ -353,4 +360,46 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
         """.stripMargin.replaceAll("\n", " "))
     }
   }
+
+  test("Remap types via JdbcDialects") {
+    JdbcDialects.registerDialect(testH2Dialect)
+    val df = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties)
+    assert(df.schema.filter(
+      _.dataType != org.apache.spark.sql.types.StringType
+    ).isEmpty)
+    val rows = df.collect()
+    assert(rows(0).get(0).isInstanceOf[String])
+    assert(rows(0).get(1).isInstanceOf[String])
+    JdbcDialects.unregisterDialect(testH2Dialect)
+  }
+
+  test("Default jdbc dialect registration") {
+    assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect)
+    assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect)
+    assert(JdbcDialects.get("test.invalid") == NoopDialect)
+  }
+
+  test("Dialect unregister") {
+    JdbcDialects.registerDialect(testH2Dialect)
+    JdbcDialects.unregisterDialect(testH2Dialect)
+    assert(JdbcDialects.get(urlWithUserAndPass) == NoopDialect)
+  }
+
+  test("Aggregated dialects") {
+    val agg = new AggregatedDialect(List(new JdbcDialect {
+      def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:")
+      override def getCatalystType(
+          sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
+        if (sqlType % 2 == 0) {
+          Some(LongType)
+        } else {
+          None
+        }
+    }, testH2Dialect))
+    assert(agg.canHandle("jdbc:h2:xxx"))
+    assert(!agg.canHandle("jdbc:h2"))
+    assert(agg.getCatalystType(0,"",1,null) == Some(LongType))
+    assert(agg.getCatalystType(1,"",1,null) == Some(StringType))
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org