Posted to commits@bahir.apache.org by lr...@apache.org on 2019/06/11 11:04:44 UTC
[bahir] branch master updated: [BAHIR-192] Add jdbc sink for structured streaming. (#81)
This is an automated email from the ASF dual-hosted git repository.
lresende pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/bahir.git
The following commit(s) were added to refs/heads/master by this push:
new d1200cb [BAHIR-192] Add jdbc sink for structured streaming. (#81)
d1200cb is described below
commit d1200cb1cab57ac337b067443d47cef67d574fbd
Author: Wang Yanlin <32...@users.noreply.github.com>
AuthorDate: Tue Jun 11 19:04:39 2019 +0800
[BAHIR-192] Add jdbc sink for structured streaming. (#81)
---
pom.xml | 4 +
sql-streaming-jdbc/README.md | 81 ++++++++
.../sql/streaming/jdbc/JavaJdbcSinkDemo.java | 119 +++++++++++
.../examples/sql/streaming/jdbc/JdbcSinkDemo.scala | 83 ++++++++
sql-streaming-jdbc/pom.xml | 85 ++++++++
...org.apache.spark.sql.sources.DataSourceRegister | 18 ++
.../sql/streaming/jdbc/JdbcSourceProvider.scala | 40 ++++
.../sql/streaming/jdbc/JdbcStreamWriter.scala | 222 +++++++++++++++++++++
.../apache/bahir/sql/streaming/jdbc/JdbcUtil.scala | 119 +++++++++++
.../src/test/resources/log4j.properties | 27 +++
.../spark/sql/jdbc/JdbcStreamWriterSuite.scala | 193 ++++++++++++++++++
11 files changed, 991 insertions(+)
diff --git a/pom.xml b/pom.xml
index b55b0ac..788a4bb 100644
--- a/pom.xml
+++ b/pom.xml
@@ -79,6 +79,7 @@
<module>sql-cloudant</module>
<module>sql-streaming-akka</module>
<module>sql-streaming-mqtt</module>
+ <module>sql-streaming-jdbc</module>
<module>streaming-akka</module>
<module>streaming-mqtt</module>
<module>streaming-pubnub</module>
@@ -421,6 +422,9 @@
<include>**/*.py</include>
</includes>
</resource>
+ <resource>
+ <directory>src/main/resources</directory>
+ </resource>
</resources>
<pluginManagement>
diff --git a/sql-streaming-jdbc/README.md b/sql-streaming-jdbc/README.md
new file mode 100644
index 0000000..e302adb
--- /dev/null
+++ b/sql-streaming-jdbc/README.md
@@ -0,0 +1,81 @@
+A library for writing data to JDBC using Spark SQL Streaming (also known as Structured Streaming).
+
+## Linking
+
+Using SBT:
+
+ libraryDependencies += "org.apache.bahir" %% "spark-sql-streaming-jdbc" % "{{site.SPARK_VERSION}}"
+
+Using Maven:
+
+ <dependency>
+ <groupId>org.apache.bahir</groupId>
+ <artifactId>spark-sql-streaming-jdbc_{{site.SCALA_BINARY_VERSION}}</artifactId>
+ <version>{{site.SPARK_VERSION}}</version>
+ </dependency>
+
+This library can also be added to Spark jobs launched through `spark-shell` or `spark-submit` by using the `--packages` command line option.
+For example, to include it when starting the spark shell:
+
+ $ bin/spark-shell --packages org.apache.bahir:spark-sql-streaming-jdbc_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION}}
+
+Unlike using `--jars`, using `--packages` ensures that this library and its dependencies will be added to the classpath.
+The `--packages` argument can also be used with `bin/spark-submit`.
+
+This library is compiled for Scala 2.11 only, and is intended to support Spark 2.0 onwards.
+
+## Configuration options
+The configuration is passed as options on the stream writer.
+
+Name |Default | Meaning
+--- |:---:| ---
url|required, no default value|JDBC URL, e.g. 'jdbc:mysql://127.0.0.1:3306/test?characterEncoding=UTF8'
+dbtable|required, no default value|table name
driver|if not specified, DriverManager attempts to locate a driver that understands the given URL|fully qualified driver class name, e.g. 'com.mysql.jdbc.Driver'
+user|None|username for database
+password|None|password for database
batchsize|1000|number of records buffered and written to JDBC as one batch, to reduce load on the database
maxRetryNumber|4|maximum number of retries before a task's JDBC write fails
checkValidTimeoutSeconds|10|timeout in seconds for checking that the cached connection is still valid; a connection is cached to avoid creating a new JDBC connection for each batch
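
For illustration, the optional tuning options are passed the same way as the required ones. A minimal sketch, given a streaming DataFrame `df` as in the examples below (the option values are illustrative, not recommendations):

    df.writeStream
      .format("streaming-jdbc")
      .option("url", "jdbc:mysql://127.0.0.1:3306/test?characterEncoding=UTF8")
      .option("dbtable", "myTableName")
      .option("user", "my database username")
      .option("password", "my database password")
      .option("batchsize", "500")                // flush a batch to JDBC every 500 buffered records
      .option("maxRetryNumber", "3")             // retry a failed batch write up to 3 times
      .option("checkValidTimeoutSeconds", "5")   // timeout when validating the cached connection
      .option("checkpointLocation", "/path/to/checkpointdir")
      .start()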
+
+## Examples
+
+### Scala API
+An example of the Scala API, writing a streaming DataFrame to a JDBC table.
+
+ // Create DataFrame from some stream source
+ val query = df.writeStream
+ .format("streaming-jdbc")
+ .option("checkpointLocation", "/path/to/localdir")
+ .outputMode("Append")
+ .option("url", "my jdbc url")
+ .option("dbtable", "myTableName")
+ .option("driver", "com.mysql.jdbc.Driver")
+ .option("user", "my database username")
+ .option("password", "my database password")
+ .trigger(Trigger.ProcessingTime("10 seconds"))
+ .start()
+
+ query.awaitTermination()
+
+Please see `JdbcSinkDemo.scala` for a full example.
+
+### Java API
+An example of the Java API, writing a streaming Dataset to a JDBC table.
+
+ StreamingQuery query = result
+ .writeStream()
+ .outputMode("append")
+ .format("streaming-jdbc")
+ .outputMode(OutputMode.Append())
+ .option(JDBCOptions.JDBC_URL(), jdbcUrl)
+ .option(JDBCOptions.JDBC_TABLE_NAME(), tableName)
+ .option(JDBCOptions.JDBC_DRIVER_CLASS(), "com.mysql.jdbc.Driver")
+ .option(JDBCOptions.JDBC_BATCH_INSERT_SIZE(), "5")
+ .option("user", username)
+ .option("password", password)
+ .trigger(Trigger.ProcessingTime("10 seconds"))
+ .start();
+ query.awaitTermination();
+
+Please see `JavaJdbcSinkDemo.java` for a full example.
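
Note that the sink expects the target table to already exist: the writer reads the table schema on start and fails if it cannot be found. A minimal sketch, assuming a local MySQL instance and the Person(name, age) shape used by the demos (the table name is hypothetical), of creating a matching table over plain JDBC:

    import java.sql.DriverManager

    // Hypothetical connection details; adjust to your environment.
    val conn = DriverManager.getConnection(
      "jdbc:mysql://127.0.0.1:3306/test", "my database username", "my database password")
    try {
      // The primary key matters: the sink writes with REPLACE INTO, so rows that
      // share a key value overwrite each other instead of producing duplicates.
      conn.createStatement().executeUpdate(
        "CREATE TABLE personTable (name VARCHAR(32), age INT, PRIMARY KEY (name))")
    } finally {
      conn.close()
    }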
diff --git a/sql-streaming-jdbc/examples/src/main/java/org/apache/bahir/examples/sql/streaming/jdbc/JavaJdbcSinkDemo.java b/sql-streaming-jdbc/examples/src/main/java/org/apache/bahir/examples/sql/streaming/jdbc/JavaJdbcSinkDemo.java
new file mode 100644
index 0000000..28844c4
--- /dev/null
+++ b/sql-streaming-jdbc/examples/src/main/java/org/apache/bahir/examples/sql/streaming/jdbc/JavaJdbcSinkDemo.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.bahir.examples.sql.streaming.jdbc;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.function.MapFunction;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions;
+import org.apache.spark.sql.streaming.OutputMode;
+import org.apache.spark.sql.streaming.StreamingQuery;
+import org.apache.spark.sql.streaming.Trigger;
+
+/**
+ * Uses the rate source to mock input, maps each value to a simple Person
+ * object with name and age properties, and writes the result to JDBC.
+ *
+ * Usage: JavaJdbcSinkDemo <jdbcUrl> <tableName> <username> <password>
+ */
+public class JavaJdbcSinkDemo {
+
+ public static void main(String[] args) throws Exception{
+ if (args.length < 4) {
+ System.err.println("Usage: JavaJdbcSinkDemo <jdbcUrl> <tableName> <username> <password>");
+ System.exit(1);
+ }
+
+ String jdbcUrl = args[0];
+ String tableName = args[1];
+ String username = args[2];
+ String password = args[3];
+
+ SparkConf sparkConf = new SparkConf().setAppName("JavaJdbcSinkDemo");
+
+ SparkSession spark = SparkSession.builder()
+ .config(sparkConf)
+ .getOrCreate();
+
+ // load data source
+ Dataset<Long> lines = spark
+ .readStream()
+ .format("rate")
+ .option("numPartitions", "5")
+ .option("rowsPerSecond", "100")
+ .load().select("value").as(Encoders.LONG());
+ // change input value to a person object.
+ DemoMapFunction demoFunction = new DemoMapFunction();
+ Dataset<Person> result = lines.map(demoFunction, Encoders.javaSerialization(Person.class));
+
+ // print schema for debug
+ result.printSchema();
+
+ StreamingQuery query = result
+ .writeStream()
+ .outputMode("append")
+ .format("streaming-jdbc")
+ .outputMode(OutputMode.Append())
+ .option(JDBCOptions.JDBC_URL(), jdbcUrl)
+ .option(JDBCOptions.JDBC_TABLE_NAME(), tableName)
+ .option(JDBCOptions.JDBC_DRIVER_CLASS(), "com.mysql.jdbc.Driver")
+ .option(JDBCOptions.JDBC_BATCH_INSERT_SIZE(), "5")
+ .option("user", username)
+ .option("password", password)
+ .trigger(Trigger.ProcessingTime("10 seconds"))
+ .start();
+ query.awaitTermination();
+
+ }
+
+ private static class Person {
+ private String name;
+ private int age;
+
+ Person(String name, int age) {
+ this.name = name;
+ this.age = age;
+ }
+
+ public String getName() {
+ return name;
+ }
+
+ public void setName(String name) {
+ this.name = name;
+ }
+
+ public int getAge() {
+ return age;
+ }
+
+ public void setAge(int age) {
+ this.age = age;
+ }
+ }
+
+ private static class DemoMapFunction implements MapFunction<Long, Person> {
+
+ @Override
+ public Person call(Long value) throws Exception {
+ return new Person("name_" + value, value.intValue() % 30);
+ }
+ }
+}
diff --git a/sql-streaming-jdbc/examples/src/main/scala/org/apache/bahir/examples/sql/streaming/jdbc/JdbcSinkDemo.scala b/sql-streaming-jdbc/examples/src/main/scala/org/apache/bahir/examples/sql/streaming/jdbc/JdbcSinkDemo.scala
new file mode 100644
index 0000000..ff51909
--- /dev/null
+++ b/sql-streaming-jdbc/examples/src/main/scala/org/apache/bahir/examples/sql/streaming/jdbc/JdbcSinkDemo.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.bahir.examples.sql.streaming.jdbc
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+import org.apache.spark.sql.streaming.{OutputMode, Trigger}
+
+/**
+ * Uses the rate source to mock input, maps each value to a simple Person
+ * object with name and age properties, and writes the result to JDBC.
+ *
+ * Usage: JdbcSinkDemo <jdbcUrl> <tableName> <username> <password>
+ */
+object JdbcSinkDemo {
+
+ private case class Person(name: String, age: Int)
+
+ def main(args: Array[String]): Unit = {
+ if (args.length < 4) {
+ // scalastyle:off println
+ System.err.println("Usage: JdbcSinkDemo <jdbcUrl> <tableName> <username> <password>")
+ // scalastyle:on
+ System.exit(1)
+ }
+
+ val jdbcUrl = args(0)
+ val tableName = args(1)
+ val username = args(2)
+ val password = args(3)
+
+ val spark = SparkSession
+ .builder()
+ .appName("JdbcSinkDemo")
+ .getOrCreate()
+
+ // load data source
+ val df = spark.readStream
+ .format("rate")
+ .option("numPartitions", "5")
+ .option("rowsPerSecond", "100")
+ .load()
+
+ // change input value to a person object.
+ import spark.implicits._
+ val lines = df.select("value").as[Long].map{ value =>
+ Person(s"name_${value}", value.toInt % 30)
+ }
+
+ lines.printSchema()
+
+ // write result
+ val query = lines.writeStream
+ .outputMode("append")
+ .format("streaming-jdbc")
+ .outputMode(OutputMode.Append)
+ .option(JDBCOptions.JDBC_URL, jdbcUrl)
+ .option(JDBCOptions.JDBC_TABLE_NAME, tableName)
+ .option(JDBCOptions.JDBC_DRIVER_CLASS, "com.mysql.jdbc.Driver")
+ .option(JDBCOptions.JDBC_BATCH_INSERT_SIZE, "5")
+ .option("user", username)
+ .option("password", password)
+ .trigger(Trigger.ProcessingTime("10 seconds"))
+ .start()
+
+ query.awaitTermination()
+ }
+}
diff --git a/sql-streaming-jdbc/pom.xml b/sql-streaming-jdbc/pom.xml
new file mode 100644
index 0000000..70cf7b6
--- /dev/null
+++ b/sql-streaming-jdbc/pom.xml
@@ -0,0 +1,85 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.bahir</groupId>
+ <artifactId>bahir-parent_2.12</artifactId>
+ <version>2.4.0-SNAPSHOT</version>
+ <relativePath>../pom.xml</relativePath>
+ </parent>
+
+ <groupId>org.apache.bahir</groupId>
+ <artifactId>spark-sql-streaming-jdbc_2.12</artifactId>
+ <properties>
+ <sbt.project.name>sql-streaming-jdbc</sbt.project.name>
+ </properties>
+ <packaging>jar</packaging>
+ <name>Apache Bahir - Spark SQL Streaming JDBC</name>
+ <url>http://bahir.apache.org/</url>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.bahir</groupId>
+ <artifactId>bahir-common_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-tags_${scala.binary.version}</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.scalacheck</groupId>
+ <artifactId>scalacheck_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.h2database</groupId>
+ <artifactId>h2</artifactId>
+ <version>1.4.195</version>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+</project>
diff --git a/sql-streaming-jdbc/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql-streaming-jdbc/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
new file mode 100644
index 0000000..a27a8f9
--- /dev/null
+++ b/sql-streaming-jdbc/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -0,0 +1,18 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+org.apache.bahir.sql.streaming.jdbc.JdbcSourceProvider
\ No newline at end of file
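
For illustration only: Spark resolves the `streaming-jdbc` short name by scanning the classpath for `DataSourceRegister` implementations through the standard `java.util.ServiceLoader` mechanism, which is what the service file above feeds. A minimal sketch of that lookup:

    import java.util.ServiceLoader

    import scala.collection.JavaConverters._

    import org.apache.spark.sql.sources.DataSourceRegister

    // Scan registered data sources and pick the one with this module's short name.
    val provider = ServiceLoader.load(classOf[DataSourceRegister]).asScala
      .find(_.shortName() == "streaming-jdbc")
    provider.foreach(p => println(p.getClass.getName))
    // expected: org.apache.bahir.sql.streaming.jdbc.JdbcSourceProvider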
diff --git a/sql-streaming-jdbc/src/main/scala/org/apache/bahir/sql/streaming/jdbc/JdbcSourceProvider.scala b/sql-streaming-jdbc/src/main/scala/org/apache/bahir/sql/streaming/jdbc/JdbcSourceProvider.scala
new file mode 100644
index 0000000..b256c17
--- /dev/null
+++ b/sql-streaming-jdbc/src/main/scala/org/apache/bahir/sql/streaming/jdbc/JdbcSourceProvider.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.bahir.sql.streaming.jdbc
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.StructType
+
+class JdbcSourceProvider extends StreamWriteSupport with DataSourceRegister{
+ override def createStreamWriter(queryId: String, schema: StructType,
+ mode: OutputMode, options: DataSourceOptions): StreamWriter = {
+ val optionMap = options.asMap().asScala.toMap
+ // Construct JDBCOptions only to validate the supplied parameters.
+ new JDBCOptions(optionMap)
+ new JdbcStreamWriter(schema, optionMap)
+ }
+
+ // The short name 'jdbc' is already used for the batch source, so choose a different name for streaming.
+ override def shortName(): String = "streaming-jdbc"
+}
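
The `new JDBCOptions(optionMap)` call above exists only to validate the supplied options. A minimal sketch of the fail-fast behaviour it provides (the option map is illustrative):

    import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions

    // 'url' is required, so constructing JDBCOptions without it fails immediately with
    // "requirement failed: Option 'url' is required." (see the test suite below).
    try {
      new JDBCOptions(Map(JDBCOptions.JDBC_TABLE_NAME -> "my_table"))
    } catch {
      case e: IllegalArgumentException => println(e.getMessage)
    }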
diff --git a/sql-streaming-jdbc/src/main/scala/org/apache/bahir/sql/streaming/jdbc/JdbcStreamWriter.scala b/sql-streaming-jdbc/src/main/scala/org/apache/bahir/sql/streaming/jdbc/JdbcStreamWriter.scala
new file mode 100644
index 0000000..31b47b6
--- /dev/null
+++ b/sql-streaming-jdbc/src/main/scala/org/apache/bahir/sql/streaming/jdbc/JdbcStreamWriter.scala
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.bahir.sql.streaming.jdbc
+
+import java.sql.{Connection, PreparedStatement, SQLException}
+import java.util.Locale
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
+import org.apache.spark.sql.jdbc.JdbcDialects
+import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.types.StructType
+
+import org.apache.bahir.utils.Logging
+
+/**
+ * Dummy commit message. The DataSourceV2 framework requires a commit message implementation,
+ * but we don't really need to send one.
+ */
+case object JdbcWriterCommitMessage extends WriterCommitMessage
+/**
+ * A [[org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter]] for jdbc writing.
+ * Responsible for generating the writer factory.
+ */
+class JdbcStreamWriter(
+ schema: StructType,
+ options: Map[String, String]
+) extends StreamWriter with Logging {
+ override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+ log.info(s"epoch ${epochId} of JdbcStreamWriter commited!")
+ }
+ override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+ log.info(s"epoch ${epochId} of JdbcStreamWriter aborted!")
+ }
+
+ override def createWriterFactory(): DataWriterFactory[InternalRow] = {
+ new JdbcStreamWriterFactory(schema, options)
+ }
+}
+/**
+ * A [[DataWriterFactory]] for jdbc writing.
+ * Will be serialized and sent to executors to generate the per-task data writers.
+ */
+case class JdbcStreamWriterFactory(
+ schema: StructType,
+ options: Map[String, String]
+) extends DataWriterFactory[InternalRow] with Logging {
+ override def createDataWriter(
+ partitionId: Int,
+ taskId: Long,
+ epochId: Long): DataWriter[InternalRow] = {
+ log.info(s"Create date writer for TID ${taskId}, EpochId ${epochId}")
+ JdbcStreamDataWriter(schema, options)
+ }
+}
+/**
+ * A [[org.apache.spark.sql.sources.v2.writer.DataWriter]] for Jdbc writing.
+ * One data writer will be created in each partition to process incoming rows.
+ */
+case class JdbcStreamDataWriter(
+ schema: StructType,
+ options: Map[String, String]
+) extends DataWriter[InternalRow] with Logging {
+ private val jdbcOptions = new JDBCOptions(options)
+
+ // use a local cache for batch write to jdbc.
+ private val batchSize = jdbcOptions.batchSize
+ private val localBuffer = new ArrayBuffer[Row](batchSize)
+ private val maxRetryNum = options.getOrElse("maxRetryNumber", "4").toInt
+ private val checkValidTimeoutSeconds =
+ options.getOrElse("checkValidTimeoutSeconds", "10").toInt
+
+ // the first part is the column name list, the second part is the placeholder string.
+ private val sqlPart: (String, String) = {
+ val columnListBuilder = new StringBuilder()
+ val holderListBuilder = new StringBuilder()
+ schema.fields.foreach { field =>
+ columnListBuilder.append(",").append(field.name)
+ holderListBuilder.append(",?")
+ }
+ (columnListBuilder.substring(1), holderListBuilder.substring(1))
+ }
+
+ private val sql = s"REPLACE INTO ${jdbcOptions.tableOrQuery} " +
+ s"( ${sqlPart._1} ) values ( ${sqlPart._2} )"
+ log.trace(s"Sql string for jdbc writing is ${sql}")
+ private val dialect = JdbcDialects.get(jdbcOptions.url)
+ // used for batch writing.
+ private var conn: Connection = _
+ private var stmt: PreparedStatement = _
+
+ checkSchema()
+ private val setters = schema.fields.map { f =>
+ resetConnectionAndStmt()
+ JdbcUtil.makeSetter(conn, dialect, f.dataType)
+ }
+ private val numFields = schema.fields.length
+ private val nullTypes = schema.fields.map(f =>
+ JdbcUtil.getJdbcType(f.dataType, dialect).jdbcNullType)
+ /**
+ * Check data schema with table.
+ * The data schema should equal the table schema or be a subset of it,
+ * and columns with the same name in the data schema and the table schema should have the same type.
+ */
+ private def checkSchema(): Unit = {
+ resetConnectionAndStmt()
+ val tableSchemaMap = JdbcUtils.getSchemaOption(conn, jdbcOptions) match {
+ case Some(tableSchema) =>
+ log.info(s"Get table ${jdbcOptions.tableOrQuery}'s schema $tableSchema")
+ tableSchema.fields.map(field => field.name.toLowerCase(Locale.ROOT) -> field).toMap
+ case _ => throw new IllegalStateException(
+ s"Schema of table ${jdbcOptions.tableOrQuery} is not defined, make sure table exist!")
+ }
+ schema.map { field =>
+ val tableColumn = tableSchemaMap.get(field.name.toLowerCase(Locale.ROOT))
+ assert(tableColumn.isDefined,
+ s"Data column ${field.name} cannot be found in table ${jdbcOptions.tableOrQuery}")
+ assert(field.dataType == tableColumn.get.dataType,
+ s"Type of data column ${field.name} is not the same in table ${jdbcOptions.tableOrQuery}")
+ }
+ }
+ // Use a locally cached connection to avoid creating a new one for every batch.
+ private def resetConnectionAndStmt(): Unit = {
+ if (conn == null || !conn.isValid(checkValidTimeoutSeconds)) {
+ conn = JdbcUtils.createConnectionFactory(jdbcOptions)()
+ stmt = conn.prepareStatement(sql)
+ log.info("Current connection is invalid, create a new one.")
+ } else {
+ log.debug("Current connection is valid, reuse it.")
+ }
+ }
+
+ override def write(record: InternalRow): Unit = {
+ localBuffer.append(Row.fromSeq(record.copy().toSeq(schema)))
+ if (localBuffer.size == batchSize) {
+ log.debug(s"Local buffer is full with size $batchSize, do write and reset local buffer.")
+ doWriteAndResetBuffer()
+ }
+ }
+ // batch write to jdbc, retry for SQLException
+ private def doWriteAndResetBuffer(): Unit = {
+ var tryNum = 0
+ val size = localBuffer.size
+ while (tryNum <= maxRetryNum) {
+ try {
+ val start = System.currentTimeMillis()
+ val iterator = localBuffer.iterator
+ while (iterator.hasNext) {
+ val row = iterator.next()
+ var i = 0
+ while (i < numFields) {
+ if (row.isNullAt(i)) {
+ stmt.setNull(i + 1, nullTypes(i))
+ } else {
+ setters(i).apply(stmt, row, i)
+ }
+ i += 1
+ }
+ stmt.addBatch()
+ }
+ stmt.executeBatch()
+ localBuffer.clear()
+ log.debug(s"Success write $size records,"
+ + s"retry number $tryNum, cost ${System.currentTimeMillis() - start} ms")
+ tryNum = maxRetryNum + 1
+ } catch {
+ case e: SQLException =>
+ if (tryNum <= maxRetryNum) {
+ tryNum += 1
+ resetConnectionAndStmt()
+ log.warn(s"Failed to write $size records, retry number $tryNum!", e)
+ } else {
+ log.error(s"Failed to write $size records,"
+ + s"reach max retry number $maxRetryNum, abort writing!")
+ throw e
+ }
+ case e: Throwable =>
+ log.error(s"Failed to write $size records, not suited for retry , abort writing!", e)
+ throw e
+ }
+ }
+ }
+
+ private def doWriteAndClose(): Unit = {
+ if (localBuffer.nonEmpty) {
+ doWriteAndResetBuffer()
+ }
+ if (conn != null) {
+ try {
+ conn.close()
+ } catch {
+ case e: Throwable => log.error("Close connection with exception", e)
+ }
+ }
+ }
+ override def commit(): WriterCommitMessage = {
+ doWriteAndClose()
+ JdbcWriterCommitMessage
+ }
+ override def abort(): Unit = {
+ log.info(s"Abort writing with ${localBuffer.size} records in local buffer.")
+ }
+}
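
To make the statement construction above concrete, a minimal sketch of how the column list, the placeholder string, and the final REPLACE INTO statement (a MySQL/H2-style upsert) are derived from a schema; the table name is illustrative:

    import org.apache.spark.sql.types._

    val schema = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType)))

    // Same derivation as sqlPart: comma-separated column names and one '?' per column.
    val columns = schema.fields.map(_.name).mkString(",")
    val holders = schema.fields.map(_ => "?").mkString(",")
    val sql = s"REPLACE INTO personTable ( $columns ) values ( $holders )"
    // => REPLACE INTO personTable ( name,age ) values ( ?,? )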
diff --git a/sql-streaming-jdbc/src/main/scala/org/apache/bahir/sql/streaming/jdbc/JdbcUtil.scala b/sql-streaming-jdbc/src/main/scala/org/apache/bahir/sql/streaming/jdbc/JdbcUtil.scala
new file mode 100644
index 0000000..7184db9
--- /dev/null
+++ b/sql-streaming-jdbc/src/main/scala/org/apache/bahir/sql/streaming/jdbc/JdbcUtil.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.bahir.sql.streaming.jdbc
+
+import java.sql.{Connection, PreparedStatement}
+import java.util.Locale
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
+import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Util functions for JDBC tables.
+ * Because the access level of `JdbcUtils.makeSetter` and `JdbcUtils.getJdbcType` is private,
+ * we re-implement `makeSetter` and `getJdbcType` here;
+ * if the access level of `JdbcUtils.makeSetter` and `JdbcUtils.getJdbcType` changes later,
+ * this `JdbcUtil` object can be removed.
+ */
+object JdbcUtil {
+
+ def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
+ dialect.getJDBCType(dt).orElse(JdbcUtils.getCommonJDBCType(dt)).getOrElse(
+ throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
+ }
+
+ // A `JDBCValueSetter` sets a value from a `Row` into a field of a `PreparedStatement`.
+ // The last `Int` argument is the zero-based index of the value in the `Row`;
+ // the corresponding parameter index in the SQL statement is that index plus one.
+ type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit
+
+ def makeSetter(
+ conn: Connection,
+ dialect: JdbcDialect,
+ dataType: DataType): JDBCValueSetter = dataType match {
+ case IntegerType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setInt(pos + 1, row.getInt(pos))
+
+ case LongType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setLong(pos + 1, row.getLong(pos))
+
+ case DoubleType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setDouble(pos + 1, row.getDouble(pos))
+
+ case FloatType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setFloat(pos + 1, row.getFloat(pos))
+
+ case ShortType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setInt(pos + 1, row.getShort(pos))
+
+ case ByteType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setInt(pos + 1, row.getByte(pos))
+
+ case BooleanType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setBoolean(pos + 1, row.getBoolean(pos))
+
+ case StringType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ val strValue = row.get(pos) match {
+ case str: UTF8String => str.toString
+ case str: String => str
+ }
+ stmt.setString(pos + 1, strValue)
+
+ case BinaryType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))
+
+ case TimestampType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))
+
+ case DateType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))
+
+ case t: DecimalType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setBigDecimal(pos + 1, row.getDecimal(pos))
+
+ case ArrayType(et, _) =>
+ // remove type length parameters from end of type name
+ val typeName = getJdbcType(et, dialect).databaseTypeDefinition
+ .toLowerCase(Locale.ROOT).split("\\(")(0)
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ val array = conn.createArrayOf(
+ typeName,
+ row.getSeq[AnyRef](pos).toArray)
+ stmt.setArray(pos + 1, array)
+
+ case _ =>
+ (_: PreparedStatement, _: Row, pos: Int) =>
+ throw new IllegalArgumentException(
+ s"Can't translate non-null value for field $pos")
+ }
+}
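
A minimal sketch of how these setters are used, mirroring the write path in `JdbcStreamDataWriter`; the in-memory H2 database (a test dependency of this module) and the table are assumptions for illustration:

    import java.sql.DriverManager

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.jdbc.JdbcDialects
    import org.apache.spark.sql.types._

    import org.apache.bahir.sql.streaming.jdbc.JdbcUtil

    val url = "jdbc:h2:mem:setterdemo"
    val conn = DriverManager.getConnection(url)
    conn.createStatement().executeUpdate(
      "CREATE TABLE person (name VARCHAR(32), age INT)")

    val schema = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType)))
    val dialect = JdbcDialects.get(url)
    // One setter per column, chosen from the column's Catalyst data type.
    val setters = schema.fields.map(f => JdbcUtil.makeSetter(conn, dialect, f.dataType))

    val stmt = conn.prepareStatement("INSERT INTO person ( name, age ) values ( ?, ? )")
    val row = Row("alice", 30)
    // Each setter reads row(i) and binds it to statement parameter i + 1.
    setters.zipWithIndex.foreach { case (setter, i) => setter(stmt, row, i) }
    stmt.executeUpdate()
    conn.close()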
diff --git a/sql-streaming-jdbc/src/test/resources/log4j.properties b/sql-streaming-jdbc/src/test/resources/log4j.properties
new file mode 100644
index 0000000..3706a6e
--- /dev/null
+++ b/sql-streaming-jdbc/src/test/resources/log4j.properties
@@ -0,0 +1,27 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Set everything to be logged to the file target/unit-tests.log
+log4j.rootCategory=INFO, file
+log4j.appender.file=org.apache.log4j.FileAppender
+log4j.appender.file.append=true
+log4j.appender.file.file=target/unit-tests.log
+log4j.appender.file.layout=org.apache.log4j.PatternLayout
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
+
+# Ignore messages below warning level from Jetty, because it's a bit verbose
+log4j.logger.org.spark_project.jetty=WARN
diff --git a/sql-streaming-jdbc/src/test/scala/org/apache/spark/sql/jdbc/JdbcStreamWriterSuite.scala b/sql-streaming-jdbc/src/test/scala/org/apache/spark/sql/jdbc/JdbcStreamWriterSuite.scala
new file mode 100644
index 0000000..53b0ac5
--- /dev/null
+++ b/sql-streaming-jdbc/src/test/scala/org/apache/spark/sql/jdbc/JdbcStreamWriterSuite.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.DriverManager
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest}
+import org.apache.spark.util.Utils
+
+private case class TestData(name: String, value: Long)
+
+class JdbcStreamWriteSuite extends StreamTest with BeforeAndAfter{
+ import testImplicits._
+
+ val url = "jdbc:h2:mem:testdb"
+ val jdbcTableName = "stream_test_table"
+ val driverClassName = "org.h2.Driver"
+ val createTableSql = s"""
+ |CREATE TABLE ${jdbcTableName}(
+ | name VARCHAR(32),
+ | value LONG,
+ | PRIMARY KEY (name)
+ |)""".stripMargin
+
+ var conn: java.sql.Connection = null
+
+ val testH2Dialect = new JdbcDialect {
+ override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2")
+ override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
+ }
+
+ before {
+ Utils.classForName(driverClassName)
+ conn = DriverManager.getConnection(url)
+ conn.prepareStatement(createTableSql).executeUpdate()
+ }
+
+ after {
+ conn.close()
+ }
+
+ test("Basic Write") {
+ withTempDir { checkpointDir => {
+ val input = MemoryStream[Int]
+ val query = input.toDF().map { row =>
+ val value = row.getInt(0)
+ TestData(s"name_$value", value.toLong)
+ }.writeStream
+ .format("streaming-jdbc")
+ .option(JDBCOptions.JDBC_URL, url)
+ .option(JDBCOptions.JDBC_TABLE_NAME, jdbcTableName)
+ .option(JDBCOptions.JDBC_DRIVER_CLASS, driverClassName)
+ .option("checkpointLocation", checkpointDir.getCanonicalPath)
+ .start()
+ try {
+ input.addData(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+ query.processAllAvailable()
+ } finally {
+ query.stop()
+ }
+ }
+ }
+ val result = conn
+ .prepareStatement(s"select count(*) as count from $jdbcTableName")
+ .executeQuery()
+ assert(result.next())
+ assert(result.getInt("count") == 10)
+ }
+
+ test("Write sub columns") {
+ withTempDir { checkpointDir => {
+ val input = MemoryStream[Int]
+ val query = input.toDF().map { row =>
+ val value = row.getInt(0)
+ TestData(s"name_$value", value.toLong)
+ }.select("name").writeStream // write just one `name` column
+ .format("streaming-jdbc")
+ .option(JDBCOptions.JDBC_URL, url)
+ .option(JDBCOptions.JDBC_TABLE_NAME, jdbcTableName)
+ .option(JDBCOptions.JDBC_DRIVER_CLASS, driverClassName)
+ .option("checkpointLocation", checkpointDir.getCanonicalPath)
+ .start()
+ try {
+ input.addData(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+ query.processAllAvailable()
+ } finally {
+ query.stop()
+ }
+ }
+ }
+ val result = conn
+ .prepareStatement(s"select count(*) as count from $jdbcTableName")
+ .executeQuery()
+ assert(result.next())
+ assert(result.getInt("count") == 10)
+ }
+
+ test("Write same data") {
+ withTempDir { checkpointDir => {
+ val input = MemoryStream[Int]
+ val query = input.toDF().map { row =>
+ val value = row.getInt(0)
+ TestData(s"name_$value", value.toLong)
+ }.writeStream
+ .format("streaming-jdbc")
+ .option(JDBCOptions.JDBC_URL, url)
+ .option(JDBCOptions.JDBC_TABLE_NAME, jdbcTableName)
+ .option(JDBCOptions.JDBC_DRIVER_CLASS, driverClassName)
+ .option("checkpointLocation", checkpointDir.getCanonicalPath)
+ .start()
+ try {
+ input.addData(1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
+ query.processAllAvailable()
+ } finally {
+ query.stop()
+ }
+ }
+ }
+ val result = conn
+ .prepareStatement(s"select count(*) as count from $jdbcTableName")
+ .executeQuery()
+ assert(result.next())
+ assert(result.getInt("count") == 1)
+ }
+
+ test("Write without required parameter") {
+ // without jdbc url
+ val thrown = intercept[StreamingQueryException] {
+ withTempDir { checkpointDir => {
+ val input = MemoryStream[Int]
+ val query = input.toDF().map { row =>
+ val value = row.getInt(0)
+ TestData(s"name_$value", value.toLong)
+ }.writeStream
+ .format("streaming-jdbc")
+ .option(JDBCOptions.JDBC_TABLE_NAME, jdbcTableName)
+ .option(JDBCOptions.JDBC_DRIVER_CLASS, driverClassName)
+ .option("checkpointLocation", checkpointDir.getCanonicalPath)
+ .start()
+ try {
+ input.addData(1, 2, 3)
+ query.processAllAvailable()
+ } finally {
+ query.stop()
+ }
+ }
+ }
+ }
+ assert(thrown.getMessage.contains("requirement failed: Option 'url' is required."))
+ // without table name
+ val thrown2 = intercept[StreamingQueryException] {
+ withTempDir { checkpointDir => {
+ val input = MemoryStream[Int]
+ val query = input.toDF().map { row =>
+ val value = row.getInt(0)
+ TestData(s"name_$value", value.toLong)
+ }.writeStream
+ .format("streaming-jdbc")
+ .option(JDBCOptions.JDBC_URL, url)
+ .option(JDBCOptions.JDBC_DRIVER_CLASS, driverClassName)
+ .option("checkpointLocation", checkpointDir.getCanonicalPath)
+ .start()
+ try {
+ input.addData(1, 2, 3)
+ query.processAllAvailable()
+ } finally {
+ query.stop()
+ }
+ }
+ }
+ }
+ assert(thrown2.getMessage.contains("Option 'dbtable' or 'query' is required"))
+ }
+}